diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index c1ef25d5a..5c15eea5e 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -156,7 +156,9 @@ def xformers_forward( f" {attn_output.size()}" ) - attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.transpose(1, 2).contiguous() + #end x-formers vs. not x-formers if-else block + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) if self.pretraining_tp > 1: