Update more tests and improve HQQ validation
This commit is contained in:
@@ -243,16 +243,11 @@ def normalize_config(cfg):
|
|||||||
elif cfg.quantization.bits == 4:
|
elif cfg.quantization.bits == 4:
|
||||||
cfg.load_in_4bit = True
|
cfg.load_in_4bit = True
|
||||||
|
|
||||||
elif cfg.quantization.backend == "gptq":
|
if cfg.quantization.backend == "gptq":
|
||||||
cfg.gptq = True
|
cfg.gptq = True
|
||||||
elif cfg.quantization.backend == "hqq":
|
elif cfg.quantization.backend == "hqq":
|
||||||
cfg.hqq = True
|
cfg.hqq = True
|
||||||
|
|
||||||
if cfg.hqq and not cfg.quantization.hqq_config:
|
|
||||||
raise ValueError(
|
|
||||||
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def normalize_cfg_datasets(cfg):
|
def normalize_cfg_datasets(cfg):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -838,6 +838,7 @@ class ModelLoader:
|
|||||||
(not self.cfg.quantization)
|
(not self.cfg.quantization)
|
||||||
and (not self.cfg.load_in_8bit)
|
and (not self.cfg.load_in_8bit)
|
||||||
and (not self.cfg.load_in_4bit)
|
and (not self.cfg.load_in_4bit)
|
||||||
|
and not self.cfg.gptq
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||||
|
|||||||
@@ -87,24 +87,24 @@ class LoraConfig(BaseModel):
|
|||||||
def validate_qlora(self):
|
def validate_qlora(self):
|
||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
if self.quantization.bits == 8:
|
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't merge qlora if using gptq")
|
raise ValueError("Can't merge qlora if using gptq")
|
||||||
|
|
||||||
if self.quantization.bits == 4:
|
if self.quantization.bits == 4 or self.cfg.load_in_4bit:
|
||||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.quantization:
|
if self.quantization:
|
||||||
if self.quantization.bits == 8:
|
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
|
||||||
raise ValueError("Can't load qlora in 8bit")
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't load qlora if using gptq")
|
raise ValueError("Can't load qlora if using gptq")
|
||||||
|
|
||||||
if not self.quantization.bits == 4:
|
if not self.quantization.bits == 4 or self.cfg.load_in_4bit:
|
||||||
raise ValueError("Require quantization.bits <= 4 for qlora")
|
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
@@ -123,6 +123,24 @@ class LoraConfig(BaseModel):
|
|||||||
data["lora_dropout"] = 0.0
|
data["lora_dropout"] = 0.0
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def validate_hqq(cls, data):
|
||||||
|
if (
|
||||||
|
data.get("quantization")
|
||||||
|
and data.get("quantization").get("backend") == "hqq"
|
||||||
|
):
|
||||||
|
if not data.get("quantization").get("hqq_config"):
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ, must set `hqq_config` under `quantization`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if data.get("load_in_4bit") or data.get("load_in_8bit"):
|
||||||
|
raise ValueError(
|
||||||
|
"If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ReLoRAConfig(BaseModel):
|
class ReLoRAConfig(BaseModel):
|
||||||
"""ReLoRA configuration subset"""
|
"""ReLoRA configuration subset"""
|
||||||
|
|||||||
@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"sample_packing": True,
|
"sample_packing": True,
|
||||||
"flash_attention": True,
|
"flash_attention": True,
|
||||||
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
},
|
||||||
"load_in_8bit": True,
|
"load_in_8bit": True,
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
"gptq": True,
|
"gptq": True,
|
||||||
|
|||||||
Reference in New Issue
Block a user