beta support for multipack with gemmoe (#1402)
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
"""multipack patching for v2 of sample packing"""
|
"""multipack patching for v2 of sample packing"""
|
||||||
|
import importlib
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
from accelerate import init_empty_weights
|
||||||
|
from transformers import AutoConfig, AutoModelForCausalLM
|
||||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||||
@@ -12,11 +15,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
|||||||
"falcon",
|
"falcon",
|
||||||
"phi",
|
"phi",
|
||||||
"gemma",
|
"gemma",
|
||||||
|
"gemmoe",
|
||||||
"starcoder2",
|
"starcoder2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def patch_for_multipack(model_type):
|
def patch_for_multipack(model_type, model_name=None):
|
||||||
if model_type == "mixtral":
|
if model_type == "mixtral":
|
||||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
|
|||||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
elif model_type == "gemmoe":
|
||||||
|
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
# we need to load the model here in order for modeling_gemmoe to be available
|
||||||
|
with init_empty_weights():
|
||||||
|
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||||
|
module_name = model_config.__class__.__module__.replace(
|
||||||
|
".configuration_gemmoe", ".modeling_gemmoe"
|
||||||
|
)
|
||||||
|
modeling_gemmoe = importlib.import_module(module_name)
|
||||||
|
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
|
get_unpad_data
|
||||||
|
)
|
||||||
|
|||||||
@@ -429,7 +429,7 @@ def load_model(
|
|||||||
and cfg.flash_attention
|
and cfg.flash_attention
|
||||||
and cfg.sample_packing
|
and cfg.sample_packing
|
||||||
):
|
):
|
||||||
patch_for_multipack(cfg.model_config_type)
|
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||||
elif cfg.is_llama_derived_model:
|
elif cfg.is_llama_derived_model:
|
||||||
# Modify all llama derived models in one block
|
# Modify all llama derived models in one block
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user