From 469e15607d7333bfd2f74f9ab15f46e7c012c076 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 20 Jun 2024 14:39:55 -0400 Subject: [PATCH] basic llama multipack --- src/axolotl/monkeypatch/multipack.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/axolotl/monkeypatch/multipack.py b/src/axolotl/monkeypatch/multipack.py index 7f6296bb6..3855ec6e3 100644 --- a/src/axolotl/monkeypatch/multipack.py +++ b/src/axolotl/monkeypatch/multipack.py @@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3 from axolotl.monkeypatch.utils import get_unpad_data SUPPORTED_MULTIPACK_MODEL_TYPES = [ + "llama", "mixtral", "qwen2", "qwen2_moe", @@ -29,6 +30,10 @@ def patch_for_multipack(model_type, model_name=None): ) if is_deepspeed_zero3_enabled(): patch_mixtral_moe_forward_zero3() + elif model_type == "llama": + transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access + get_unpad_data + ) elif model_type == "qwen2": transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access get_unpad_data