fix: don't change quant storage dtype in case of fsdp (#1837)

* fix: don't change quant storage dtype in case of fsdp

* fix black

---------

Co-authored-by: Gal Cohen <galc@ai21.com>
This commit is contained in:
Gal Cohen (galco)
2024-08-20 19:41:48 +03:00
committed by GitHub
parent e29931259b
commit 5aac4bc284

View File

@@ -544,7 +544,9 @@ def load_model(
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.bfloat16,
}
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
cfg.deepspeed or cfg.fsdp
):
# for some reason, this causes the loss to be off by an order of magnitude
# but deepspeed needs this still in bfloat16
bnb_config["bnb_4bit_quant_storage"] = torch.float32