fix keyerror on load_in_8bit/load_in_4bit access in _set_quantization_config (#3023)
* set load_in_8bit/load_in_4bit in _set_quantization_config to prevent keyerror * use dict.get instead
This commit is contained in:
@@ -612,7 +612,9 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
|
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
||||||
|
"load_in_4bit", False
|
||||||
|
):
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -638,7 +640,9 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
|
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
||||||
|
"load_in_8bit", False
|
||||||
|
):
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user