qwen2_moe support w multipack (#1455)

This commit is contained in:
Wing Lian
2024-03-29 11:04:53 -04:00
committed by GitHub
parent 4a92a3b9ee
commit 6086be85f7
6 changed files with 147 additions and 4 deletions

View File

@@ -12,6 +12,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"gemma",
@@ -31,6 +32,10 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data