add starcoder2 (#1349)

* add starcoder2

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* chore: lint

* Apply suggestions from code review

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
Eric Hartford
2024-03-05 16:49:17 -08:00
committed by GitHub
parent 8984bf1722
commit e0f1895408
2 changed files with 81 additions and 1 deletions

View File

@@ -6,7 +6,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"falcon",
"phi",
"gemma",
"starcoder2",
]
def patch_for_multipack(model_type):
@@ -32,3 +39,7 @@ def patch_for_multipack(model_type):
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "starcoder2":
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)