diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 4d3d6b68e..c1ef25d5a 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -39,6 +39,9 @@ def xformers_forward( # pylint: disable=duplicate-code bsz, q_len, _ = hidden_states.size() + if not hasattr(self, 'pretraining_tp'): + self.pretraining_tp = 1 + if self.pretraining_tp > 1: key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)