qwen2_moe support w multipack (#1455)

This commit is contained in:
Wing Lian
2024-03-29 11:04:53 -04:00
committed by GitHub
parent 4a92a3b9ee
commit 6086be85f7
6 changed files with 147 additions and 4 deletions

View File

@@ -12,6 +12,7 @@ from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = [
"mixtral",
"qwen2",
"qwen2_moe",
"falcon",
"phi",
"gemma",
@@ -31,6 +32,10 @@ def patch_for_multipack(model_type, model_name=None):
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "qwen2_moe":
transformers.models.qwen2_moe.modeling_qwen2_moe._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data

View File

@@ -456,7 +456,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type == "jamba" and not cfg.deepspeed:
if not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32