skip set_quant_config if quantization not given
This commit is contained in:
@@ -834,6 +834,8 @@ class ModelLoader:
|
|||||||
del self.model_kwargs["device_map"]
|
del self.model_kwargs["device_map"]
|
||||||
|
|
||||||
def set_quantization_config(self) -> None:
|
def set_quantization_config(self) -> None:
|
||||||
|
if not self.cfg.quantization:
|
||||||
|
return
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||||
|
|
||||||
@@ -887,7 +889,7 @@ class ModelLoader:
|
|||||||
# but deepspeed needs this still in bfloat16
|
# but deepspeed needs this still in bfloat16
|
||||||
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
bnb_config["bnb_4bit_quant_storage"] = torch.float32
|
||||||
|
|
||||||
if self.cfg.quantization and self.cfg.quantization.bnb_config_kwargs:
|
if self.cfg.quantization.bnb_config_kwargs:
|
||||||
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
bnb_config.update(self.cfg.quantization.bnb_config_kwargs)
|
||||||
|
|
||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
|
|||||||
Reference in New Issue
Block a user