fix: transformers deprecate load_in_Xbit in model_kwargs (#3205)
* fix: transformers deprecate load_in_Xbit in model_kwargs * fix: test to read from quantization_config kwarg * fix: test * fix: access * fix: test weirdly entering incorrect config
This commit is contained in:
@@ -515,9 +515,6 @@ class ModelLoader:
|
||||
if self.cfg.model_quantization_config_kwargs:
|
||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||
else:
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
||||
|
||||
if self.cfg.gptq:
|
||||
if not hasattr(self.model_config, "quantization_config"):
|
||||
@@ -552,9 +549,7 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**self.model_config.quantization_config
|
||||
)
|
||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
||||
"load_in_4bit", False
|
||||
):
|
||||
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
|
||||
bnb_config = {
|
||||
"load_in_4bit": True,
|
||||
"llm_int8_threshold": 6.0,
|
||||
@@ -580,9 +575,7 @@ class ModelLoader:
|
||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
**bnb_config,
|
||||
)
|
||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
||||
"load_in_8bit", False
|
||||
):
|
||||
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
|
||||
bnb_config = {
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -596,11 +589,6 @@ class ModelLoader:
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
||||
self.model_kwargs.pop("load_in_8bit", None)
|
||||
self.model_kwargs.pop("load_in_4bit", None)
|
||||
|
||||
def _set_attention_config(self):
|
||||
"""Sample packing uses custom FA2 patch"""
|
||||
if self.cfg.attn_implementation:
|
||||
|
||||
@@ -80,16 +80,26 @@ class TestModelsUtils:
|
||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
||||
)
|
||||
elif load_in_8bit and self.cfg.adapter is not None:
|
||||
assert self.model_loader.model_kwargs["load_in_8bit"]
|
||||
elif load_in_4bit and self.cfg.adapter is not None:
|
||||
assert self.model_loader.model_kwargs["load_in_4bit"]
|
||||
|
||||
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
||||
self.cfg.adapter == "lora" and load_in_8bit
|
||||
):
|
||||
assert self.model_loader.model_kwargs.get(
|
||||
"quantization_config", BitsAndBytesConfig
|
||||
if self.cfg.adapter == "qlora" and load_in_4bit:
|
||||
assert isinstance(
|
||||
self.model_loader.model_kwargs.get("quantization_config"),
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
assert (
|
||||
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
|
||||
is True
|
||||
)
|
||||
if self.cfg.adapter == "lora" and load_in_8bit:
|
||||
assert isinstance(
|
||||
self.model_loader.model_kwargs.get("quantization_config"),
|
||||
BitsAndBytesConfig,
|
||||
)
|
||||
|
||||
assert (
|
||||
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
|
||||
is True
|
||||
)
|
||||
|
||||
def test_message_property_mapping(self):
|
||||
|
||||
Reference in New Issue
Block a user