support for gemma2 w sample packing (#1718)
This commit is contained in:
@@ -16,6 +16,7 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
"gemma2",
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
"deepseek_v2",
|
||||
@@ -49,6 +50,10 @@ def patch_for_multipack(model_type, model_name=None):
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemma2":
|
||||
transformers.models.gemma2.modeling_gemma2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
|
||||
Reference in New Issue
Block a user