support for gemma2 w sample packing (#1718)

This commit is contained in:
Wing Lian
2024-06-29 01:38:55 -04:00
committed by GitHub
parent f2480a1d91
commit 5370cedf0c
9 changed files with 97 additions and 5 deletions

View File

@@ -16,6 +16,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"falcon",
"phi",
"gemma",
"gemma2",
"gemmoe",
"starcoder2",
"deepseek_v2",
@@ -49,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "gemma2":
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data