From e07347b188d80a0a839cbb6989af266aaa0735a8 Mon Sep 17 00:00:00 2001 From: WenboPan Date: Wed, 27 Mar 2024 03:19:44 +0800 Subject: [PATCH] Remove seq_len arg in rotary_emb (#1443) * remove seq_len in llama rotary_emb * chore: lint --------- Co-authored-by: Wing Lian --- .../monkeypatch/llama_attn_hijack_flash.py | 15 ++------------- .../monkeypatch/llama_attn_hijack_xformers.py | 6 +----- 2 files changed, 3 insertions(+), 18 deletions(-) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index f727c74b8..dda5da2b7 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -284,12 +284,7 @@ def flashattn_forward_with_s2attn( # [bsz, nh, q_len, hd] # pylint: disable=duplicate-code - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) @@ -435,13 +430,7 @@ def flashattn_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb( - value_states, seq_len=kv_seq_len, position_ids=position_ids - ) + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids ) diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index 8143750f0..0c1a4e822 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -80,11 +80,7 @@ def xformers_forward( # [bsz, q_len, nh, hd] # [bsz, nh, q_len, hd] - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids )