qwen2_moe support w multipack (#1455)

This commit is contained in:
Wing Lian
2024-03-29 11:04:53 -04:00
committed by GitHub
parent 4a92a3b9ee
commit 6086be85f7
6 changed files with 147 additions and 4 deletions

View File

@@ -456,7 +456,7 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type == "jamba" and not cfg.deepspeed:
if not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32