From 82453bab7ee4beaee0be2d0826ab557af11b4441 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 5 May 2025 03:22:02 -0400
Subject: [PATCH] handle xformers patch for inference too

---
 src/axolotl/monkeypatch/attention/xformers.py | 32 +++++++++++--------
 1 file changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
index 28eab6ec1..8e7eba42e 100644
--- a/src/axolotl/monkeypatch/attention/xformers.py
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -65,30 +65,23 @@ def xformers_attention_forward(
         key = key.reshape(-1, key.size(-2), key.size(-1))
         value = value.reshape(-1, value.size(-2), value.size(-1))
 
+        # pylint: disable=duplicate-code
         if module.config.num_attention_heads != module.config.num_key_value_heads:
             key = key.repeat_interleave(
                 module.config.num_attention_heads // module.config.num_key_value_heads,
                 dim=2,
             )
             value = value.repeat_interleave(
                 module.config.num_attention_heads // module.config.num_key_value_heads,
                 dim=2,
             )
 
-            attn_output = xformers_attention(
-                query,
-                key,
-                value,
-                attn_bias=attn_bias,
-            )
-
-        else:
-            attn_output = xformers_attention(
-                query,
-                key,
-                value,
-                attn_bias=attn_bias,
-            )
+        attn_output = xformers_attention(
+            query,
+            key,
+            value,
+            attn_bias=attn_bias,
+        )
     elif attention_mask is not None:
         batch_size = query.shape[0]
         query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
@@ -111,6 +104,17 @@ def xformers_attention_forward(
         )
         attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
     else:
+        # pylint: disable=duplicate-code
+        if module.config.num_attention_heads != module.config.num_key_value_heads:
+            key = key.repeat_interleave(
+                module.config.num_attention_heads // module.config.num_key_value_heads,
+                dim=2,
+            )
+            value = value.repeat_interleave(
+                module.config.num_attention_heads // module.config.num_key_value_heads,
+                dim=2,
+            )
+
         attn_output = xformers_attention(
             query,
             key,
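
Note: both the training path (first hunk) and the newly patched inference/else path (second hunk) now perform the same grouped-query-attention (GQA) expansion before a single xformers call: key/value heads are repeated with repeat_interleave until they match the query head count, and only then is the attention kernel invoked. The snippet below is a minimal standalone sketch of that pattern, not the axolotl module itself; the function name gqa_xformers_attention and the [batch, seq_len, heads, head_dim] layout are illustrative assumptions, and running it requires torch plus a working xformers backend.

import torch
import xformers.ops as xops


def gqa_xformers_attention(query, key, value, attn_bias=None):
    """Sketch of the GQA handling shared by the patched branches.

    query:      [batch, seq_len, n_heads, head_dim]
    key/value:  [batch, seq_len, n_kv_heads, head_dim]
    """
    n_heads = query.size(-2)
    n_kv_heads = key.size(-2)
    if n_heads != n_kv_heads:
        # Repeat each KV head so Q, K, and V all expose n_heads heads,
        # mirroring the repeat_interleave(dim=2) calls in the diff above.
        key = key.repeat_interleave(n_heads // n_kv_heads, dim=2)
        value = value.repeat_interleave(n_heads // n_kv_heads, dim=2)

    # One shared attention call regardless of branch, as in the patched forward.
    return xops.memory_efficient_attention(query, key, value, attn_bias=attn_bias)

Pulling the xformers_attention call out of the if/else in the first hunk removes the duplicated invocation per branch; the remaining duplication is the GQA block itself, which is why the pylint duplicate-code suppression is added where that block is repeated.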