multipack for gemma (#1313)

* multipack for gemma

* chore: lint

* handle cache_position kwarg in updated llama modeling

* add position_ids to rotary embed call for updated llama modeling
This commit is contained in:
Wing Lian
2024-02-21 19:24:21 -05:00
committed by GitHub
parent 9e300aca0c
commit 2752d5f958
4 changed files with 26 additions and 15 deletions

View File

@@ -275,7 +275,9 @@ def flashattn_forward_with_s2attn(
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -425,7 +427,9 @@ def flashattn_forward(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
cos, sin = self.rotary_emb(
value_states, seq_len=kv_seq_len, position_ids=position_ids
)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
@@ -688,6 +692,9 @@ def llama_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[ # pylint: disable=unused-argument
torch.LongTensor
] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = (
output_attentions

View File

@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
def patch_for_multipack(model_type):
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma":
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)