qwen2_moe support w multipack (#1455)
This commit is contained in:
@@ -12,6 +12,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"qwen2_moe",
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
@@ -31,6 +32,10 @@ def patch_for_multipack(model_type, model_name=None):
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "qwen2_moe":
|
||||
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
|
||||
Reference in New Issue
Block a user