beta support for multipack with gemmoe: (#1402)
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
import importlib
|
||||
|
||||
import transformers
|
||||
from accelerate import init_empty_weights
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
@@ -12,11 +15,12 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
|
||||
"falcon",
|
||||
"phi",
|
||||
"gemma",
|
||||
"gemmoe",
|
||||
"starcoder2",
|
||||
]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
def patch_for_multipack(model_type, model_name=None):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
@@ -43,3 +47,15 @@ def patch_for_multipack(model_type):
|
||||
transformers.models.starcoder2.modeling_starcoder2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "gemmoe":
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_gemmoe to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_gemmoe", ".modeling_gemmoe"
|
||||
)
|
||||
modeling_gemmoe = importlib.import_module(module_name)
|
||||
modeling_gemmoe._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
|
||||
@@ -429,7 +429,7 @@ def load_model(
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
patch_for_multipack(cfg.model_config_type, model_name=cfg.base_model)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
|
||||
Reference in New Issue
Block a user