diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 0c4f1cf80..7c2877180 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -853,7 +853,7 @@ class ModelLoader: if ( self.cfg.adapter in ["qlora", "lora"] and hasattr(self.model_config, "quantization_config") - and getattr(self.model_config.quantization_config, "quant_method") + and self.model_config.quantization_config["quant_method"] in ["gptq", "awq", "bitsandbytes", "hqq"] ): quant_config_class_dict = { @@ -864,7 +864,7 @@ class ModelLoader: } quant_config_class = quant_config_class_dict[ - getattr(self.model_config.quantization_config, "quant_method") + self.model_config.quantization_config["quant_method"] ] self.model_kwargs["quantization_config"] = quant_config_class( **self.model_config.quantization_config