From aa1240acd8d7e9640a01a78e9da8a0725b158041 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Thu, 16 Oct 2025 16:07:27 +0700
Subject: [PATCH] fix: transformers deprecate load_in_Xbit in model_kwargs
 (#3205)

* fix: transformers deprecate load_in_Xbit in model_kwargs

* fix: test to read from quantization_config kwarg

* fix: test

* fix: access

* fix: test weirdly entering incorrect config
---
 src/axolotl/loaders/model.py | 16 ++--------------
 tests/test_loaders.py        | 28 +++++++++++++++++++---------
 2 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index f438d6b61..aeec46584 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -515,9 +515,6 @@ class ModelLoader:
         if self.cfg.model_quantization_config_kwargs:
             mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
             self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
-        else:
-            self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
-            self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
 
         if self.cfg.gptq:
             if not hasattr(self.model_config, "quantization_config"):
@@ -552,9 +549,7 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **self.model_config.quantization_config
             )
-        elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
-            "load_in_4bit", False
-        ):
+        elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
             bnb_config = {
                 "load_in_4bit": True,
                 "llm_int8_threshold": 6.0,
@@ -580,9 +575,7 @@ class ModelLoader:
             self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
                 **bnb_config,
             )
-        elif self.cfg.adapter == "lora" and self.model_kwargs.get(
-            "load_in_8bit", False
-        ):
+        elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
             bnb_config = {
                 "load_in_8bit": True,
             }
@@ -596,11 +589,6 @@ class ModelLoader:
                 **bnb_config,
             )
 
-        # no longer needed per https://github.com/huggingface/transformers/pull/26610
-        if "quantization_config" in self.model_kwargs or self.cfg.gptq:
-            self.model_kwargs.pop("load_in_8bit", None)
-            self.model_kwargs.pop("load_in_4bit", None)
-
     def _set_attention_config(self):
         """Sample packing uses custom FA2 patch"""
         if self.cfg.attn_implementation:
diff --git a/tests/test_loaders.py b/tests/test_loaders.py
index f516d0ca4..913090566 100644
--- a/tests/test_loaders.py
+++ b/tests/test_loaders.py
@@ -80,16 +80,26 @@ class TestModelsUtils:
             hasattr(self.model_loader.model_kwargs, "load_in_8bit")
             and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
         )
-        elif load_in_8bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_8bit"]
-        elif load_in_4bit and self.cfg.adapter is not None:
-            assert self.model_loader.model_kwargs["load_in_4bit"]
-        if (self.cfg.adapter == "qlora" and load_in_4bit) or (
-            self.cfg.adapter == "lora" and load_in_8bit
-        ):
-            assert self.model_loader.model_kwargs.get(
-                "quantization_config", BitsAndBytesConfig
+        if self.cfg.adapter == "qlora" and load_in_4bit:
+            assert isinstance(
+                self.model_loader.model_kwargs.get("quantization_config"),
+                BitsAndBytesConfig,
+            )
+
+            assert (
+                self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
+                is True
+            )
+        if self.cfg.adapter == "lora" and load_in_8bit:
+            assert isinstance(
+                self.model_loader.model_kwargs.get("quantization_config"),
+                BitsAndBytesConfig,
+            )
+
+            assert (
+                self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
+                is True
             )
 
     def test_message_property_mapping(self):