support for llama multipack using updated code/patches (#1754)

* support for llama multipack using updated code/patches

* also support unsloth patches

* incorrect arg

* add config validation for unsloth

* add missing return to validation

* add another missing return to validation
This commit is contained in:
Wing Lian
2024-07-16 17:36:29 -04:00
committed by GitHub
parent cfc533a7f7
commit 5f58555bd0
4 changed files with 95 additions and 21 deletions

View File

@@ -10,6 +10,7 @@ from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"llama",
"mixtral",
"qwen2",
"qwen2_moe",
@@ -30,6 +31,10 @@ def patch_for_multipack(model_type, model_name=None):
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "llama":
transformers.models.llama.modeling_llama._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data