Remove seq_len arg in rotary_emb (#1443)
* remove seq_len in llama rotary_emb * chore: lint --------- Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
@@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn(
|
|||||||
# [bsz, nh, q_len, hd]
|
# [bsz, nh, q_len, hd]
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
cos, sin = self.rotary_emb(
|
|
||||||
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
|
||||||
)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
@@ -435,13 +430,7 @@ def flashattn_forward(
|
|||||||
# [bsz, q_len, nh, hd]
|
# [bsz, q_len, nh, hd]
|
||||||
# [bsz, nh, q_len, hd]
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
cos, sin = self.rotary_emb(value_states, position_ids=position_ids)
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(
|
|
||||||
value_states, seq_len=kv_seq_len, position_ids=position_ids
|
|
||||||
)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -80,11 +80,7 @@ def xformers_forward(
|
|||||||
# [bsz, q_len, nh, hd]
|
# [bsz, q_len, nh, hd]
|
||||||
# [bsz, nh, q_len, hd]
|
# [bsz, nh, q_len, hd]
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
cos, sin = self.rotary_emb(value_states)
|
||||||
if past_key_value is not None:
|
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
|
||||||
|
|
||||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
|
||||||
query_states, key_states = apply_rotary_pos_emb(
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user