fix some of the edge cases for Jamba (#1452)

* fix some of the edge cases for Jamba

* update requirements for jamba
This commit is contained in:
Wing Lian
2024-03-29 02:38:02 -04:00
committed by GitHub
parent e634118f90
commit 05b398a072
8 changed files with 92 additions and 17 deletions

View File

@@ -456,6 +456,10 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type == "jamba" and not cfg.deepspeed:
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32
if cfg.bnb_config_kwargs:
bnb_config.update(cfg.bnb_config_kwargs)