diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py index 8e7eba42e..394169eef 100644 --- a/src/axolotl/monkeypatch/attention/xformers.py +++ b/src/axolotl/monkeypatch/attention/xformers.py @@ -39,6 +39,7 @@ def xformers_attention_forward( key = key.transpose(1, 2) value = value.transpose(1, 2) query_length = query.shape[2] + batch_size = query.size(0) attn_bias = xformers.ops.LowerTriangularMask() @@ -49,8 +50,6 @@ def xformers_attention_forward( max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()) ): - batch_size = query.size(0) - if cu_seq_lens_q is None or cu_seq_lens_k is None: cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0] cu_seq_lens_q = cu_seq_lens_q.squeeze() @@ -83,7 +82,6 @@ def xformers_attention_forward( attn_bias=attn_bias, ) elif attention_mask is not None: - batch_size = query.shape[0] query, key, value, indices_q, cu_seq_lens, _ = _upad_input( query, key, value, attention_mask, query_length )