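Fix batch size setter: assign batch_size once near the top of xformers_attention_forward instead of separately inside each branch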
This commit is contained in:
@@ -39,6 +39,7 @@ def xformers_attention_forward(
|
|||||||
key = key.transpose(1, 2)
|
key = key.transpose(1, 2)
|
||||||
value = value.transpose(1, 2)
|
value = value.transpose(1, 2)
|
||||||
query_length = query.shape[2]
|
query_length = query.shape[2]
|
||||||
|
batch_size = query.size(0)
|
||||||
|
|
||||||
attn_bias = xformers.ops.LowerTriangularMask()
|
attn_bias = xformers.ops.LowerTriangularMask()
|
||||||
|
|
||||||
@@ -49,8 +50,6 @@ def xformers_attention_forward(
|
|||||||
max_length_q is not None
|
max_length_q is not None
|
||||||
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||||
):
|
):
|
||||||
batch_size = query.size(0)
|
|
||||||
|
|
||||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||||
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||||
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||||
@@ -83,7 +82,6 @@ def xformers_attention_forward(
|
|||||||
attn_bias=attn_bias,
|
attn_bias=attn_bias,
|
||||||
)
|
)
|
||||||
elif attention_mask is not None:
|
elif attention_mask is not None:
|
||||||
batch_size = query.shape[0]
|
|
||||||
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
query, key, value, indices_q, cu_seq_lens, _ = _upad_input(
|
||||||
query, key, value, attention_mask, query_length
|
query, key, value, attention_mask, query_length
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user