fix: transformers deprecate load_in_Xbit in model_kwargs (#3205)
* fix: transformers deprecate load_in_Xbit in model_kwargs
* fix: test to read from quantization_config kwarg
* fix: test
* fix: access
* fix: test weirdly entering incorrect config
This commit is contained in:
@@ -515,9 +515,6 @@ class ModelLoader:
|
|||||||
if self.cfg.model_quantization_config_kwargs:
|
if self.cfg.model_quantization_config_kwargs:
|
||||||
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
mxfp4_kwargs = self.cfg.model_quantization_config_kwargs
|
||||||
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
self.model_kwargs["quantization_config"] = Mxfp4Config(**mxfp4_kwargs)
|
||||||
else:
|
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
|
||||||
self.model_kwargs["load_in_4bit"] = self.cfg.load_in_4bit
|
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
if not hasattr(self.model_config, "quantization_config"):
|
if not hasattr(self.model_config, "quantization_config"):
|
||||||
@@ -552,9 +549,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**self.model_config.quantization_config
|
**self.model_config.quantization_config
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "qlora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "qlora" and self.cfg.load_in_4bit:
|
||||||
"load_in_4bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_4bit": True,
|
"load_in_4bit": True,
|
||||||
"llm_int8_threshold": 6.0,
|
"llm_int8_threshold": 6.0,
|
||||||
@@ -580,9 +575,7 @@ class ModelLoader:
|
|||||||
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
elif self.cfg.adapter == "lora" and self.model_kwargs.get(
|
elif self.cfg.adapter == "lora" and self.cfg.load_in_8bit:
|
||||||
"load_in_8bit", False
|
|
||||||
):
|
|
||||||
bnb_config = {
|
bnb_config = {
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
}
|
}
|
||||||
@@ -596,11 +589,6 @@ class ModelLoader:
|
|||||||
**bnb_config,
|
**bnb_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
|
||||||
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
|
|
||||||
self.model_kwargs.pop("load_in_8bit", None)
|
|
||||||
self.model_kwargs.pop("load_in_4bit", None)
|
|
||||||
|
|
||||||
def _set_attention_config(self):
|
def _set_attention_config(self):
|
||||||
"""Sample packing uses custom FA2 patch"""
|
"""Sample packing uses custom FA2 patch"""
|
||||||
if self.cfg.attn_implementation:
|
if self.cfg.attn_implementation:
|
||||||
|
|||||||
@@ -80,16 +80,26 @@ class TestModelsUtils:
|
|||||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||||
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
and hasattr(self.model_loader.model_kwargs, "load_in_4bit")
|
||||||
)
|
)
|
||||||
elif load_in_8bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_8bit"]
|
|
||||||
elif load_in_4bit and self.cfg.adapter is not None:
|
|
||||||
assert self.model_loader.model_kwargs["load_in_4bit"]
|
|
||||||
|
|
||||||
if (self.cfg.adapter == "qlora" and load_in_4bit) or (
|
if self.cfg.adapter == "qlora" and load_in_4bit:
|
||||||
self.cfg.adapter == "lora" and load_in_8bit
|
assert isinstance(
|
||||||
):
|
self.model_loader.model_kwargs.get("quantization_config"),
|
||||||
assert self.model_loader.model_kwargs.get(
|
BitsAndBytesConfig,
|
||||||
"quantization_config", BitsAndBytesConfig
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.model_loader.model_kwargs["quantization_config"]._load_in_4bit
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
if self.cfg.adapter == "lora" and load_in_8bit:
|
||||||
|
assert isinstance(
|
||||||
|
self.model_loader.model_kwargs.get("quantization_config"),
|
||||||
|
BitsAndBytesConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
self.model_loader.model_kwargs["quantization_config"]._load_in_8bit
|
||||||
|
is True
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_message_property_mapping(self):
|
def test_message_property_mapping(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user