From 4f478083e779dca566d655c5f7b7f5f7e25ad20b Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 5 May 2025 03:49:01 -0400
Subject: [PATCH] fix batch size setter

---
 src/axolotl/monkeypatch/attention/xformers.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/axolotl/monkeypatch/attention/xformers.py b/src/axolotl/monkeypatch/attention/xformers.py
index 8e7eba42e..394169eef 100644
--- a/src/axolotl/monkeypatch/attention/xformers.py
+++ b/src/axolotl/monkeypatch/attention/xformers.py
@@ -39,6 +39,7 @@ def xformers_attention_forward(
     key = key.transpose(1, 2)
     value = value.transpose(1, 2)
     query_length = query.shape[2]
+    batch_size = query.size(0)
 
     attn_bias = xformers.ops.LowerTriangularMask()
 
@@ -49,8 +50,6 @@ def xformers_attention_forward(
         max_length_q is not None
         or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
     ):
-        batch_size = query.size(0)
-
         if cu_seq_lens_q is None or cu_seq_lens_k is None:
             cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
             cu_seq_lens_q = cu_seq_lens_q.squeeze()
@@ -83,7 +82,6 @@ def xformers_attention_forward(
             attn_bias=attn_bias,
         )
     elif attention_mask is not None:
-        batch_size = query.shape[0]
         query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
             query, key, value, attention_mask, query_length
         )