fix for fsdp for models that aren't qwen2 or jamba
This commit is contained in:
@@ -459,7 +459,7 @@ def load_model(
|
||||
"bnb_4bit_quant_type": "nf4",
|
||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||
}
|
||||
if not cfg.deepspeed:
|
||||
if not cfg.deepspeed and cfg.model_config_type in ("jamba", "qwen2_moe"):
|
||||
# for some reason, this causes the loss to be off by an order of magnitude
|
||||
# but deepspeed needs this still in bfloat16
|
||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||
|
||||
Reference in New Issue
Block a user