fix: dont change quant storage dtype in case of fsdp (#1837)
* fix: dont change quant storage dtype in case of fsdp * fix black --------- Co-authored-by: Gal Cohen <galc@ai21.com>
This commit is contained in:
committed by
GitHub
parent
e29931259b
commit
5aac4bc284
@@ -544,7 +544,9 @@ def load_model(
|
|||||||
"bnb_4bit_quant_type": "nf4",
|
"bnb_4bit_quant_type": "nf4",
|
||||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
"bnb_4bit_quant_storage": torch.bfloat16,
|
||||||
}
|
}
|
||||||
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not cfg.deepspeed:
|
if cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
||||||
|
cfg.deepspeed or cfg.fsdp
|
||||||
|
):
|
||||||
# for some reason, this causes the loss to be off by an order of magnitude
|
# for some reason, this causes the loss to be off by an order of magnitude
|
||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
|||||||
Reference in New Issue
Block a user