fsdp requires params be the same type too (#493)

commit 98bf76e236
parent 4c37bd0b54
Author: Wing Lian
Date:   2023-08-28 04:33:50 -04:00
Committed by: GitHub


@@ -356,7 +356,7 @@ def load_model(
             if hasattr(module, "weight"):
                 module.to(torch.float32)
 
-    needs_fa2_dtype = cfg.adapter is not None
+    needs_fa2_dtype = cfg.adapter or cfg.fsdp
     if not cfg.gptq and (
         (cfg.adapter == "lora" and load_in_8bit)
         or (cfg.adapter == "qlora" and cfg.load_in_4bit)
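
For context: FSDP flattens the parameters of many modules into shared flat buffers, and a flat buffer cannot mix dtypes, so every parameter it wraps must share one dtype. The norm layers upcast to float32 in the context lines above would otherwise clash with a bf16/fp16 base model, which is why the flag is widened to also fire when cfg.fsdp is set. A minimal standalone sketch of the failure mode and the kind of normalization this gates (illustrative only, not code from this repo):

    import torch
    import torch.nn as nn

    # A model whose norm layer was upcast to float32 (as in the hunk above),
    # leaving it with mixed parameter dtypes.
    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16)).to(torch.bfloat16)
    model[1].to(torch.float32)
    assert len({p.dtype for p in model.parameters()}) > 1  # mixed dtypes

    # What the widened flag effectively triggers: cast everything back to a
    # single dtype before FSDP wraps the model, since FSDP's flat parameter
    # buffers cannot mix dtypes.
    needs_fa2_dtype = True  # stands in for `cfg.adapter or cfg.fsdp`
    if needs_fa2_dtype:
        model.to(torch.bfloat16)

    assert len({p.dtype for p in model.parameters()}) == 1  # safe to FSDP-wrap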