multipack for gemma (#1313)
* multipack for gemma * chore: lint * handle cache_position kwarg in updated llama modeling * add position_ids to rotary embed call for updated llama modeling
This commit is contained in:
@@ -6,7 +6,7 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
@@ -28,3 +28,7 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma":
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user