use uint8 dtype for qlora
This commit is contained in:
@@ -912,7 +912,7 @@ class ModelLoader:
|
|||||||
"bnb_4bit_compute_dtype": self.cfg.torch_dtype,
|
"bnb_4bit_compute_dtype": self.cfg.torch_dtype,
|
||||||
"bnb_4bit_use_double_quant": True,
|
"bnb_4bit_use_double_quant": True,
|
||||||
"bnb_4bit_quant_type": "nf4",
|
"bnb_4bit_quant_type": "nf4",
|
||||||
"bnb_4bit_quant_storage": torch.bfloat16,
|
"bnb_4bit_quant_storage": torch.uint8,
|
||||||
}
|
}
|
||||||
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
|
||||||
self.cfg.deepspeed or self.cfg.fsdp
|
self.cfg.deepspeed or self.cfg.fsdp
|
||||||
|
|||||||
Reference in New Issue
Block a user