diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 4563d6a8f..e8dacbe40 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -834,7 +834,11 @@ class ModelLoader: del self.model_kwargs["device_map"] def set_quantization_config(self) -> None: - if not self.cfg.quantization: + if ( + (not self.cfg.quantization) + and (not self.cfg.load_in_8bit) + and (not self.cfg.load_in_4bit) + ): return self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit