add starcoder2 (#1349)
* add starcoder2 * Apply suggestions from code review Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> * chore: lint * Apply suggestions from code review Co-authored-by: NanoCode012 <kevinvong@rocketmail.com> --------- Co-authored-by: Wing Lian <wing.lian@gmail.com> Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
This commit is contained in:
@@ -6,7 +6,14 @@ from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi", "gemma"]
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"mixtral",
|
||||
"qwen2",
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
"starcoder2",
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
@@ -32,3 +39,7 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.gemma.modeling_gemma._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "starcoder2":
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user