update more tests + better hqq validation
This commit is contained in:
@@ -243,16 +243,11 @@ def normalize_config(cfg):
|
||||
elif cfg.quantization.bits == 4:
|
||||
cfg.load_in_4bit = True
|
||||
|
||||
elif cfg.quantization.backend == "gptq":
|
||||
if cfg.quantization.backend == "gptq":
|
||||
cfg.gptq = True
|
||||
elif cfg.quantization.backend == "hqq":
|
||||
cfg.hqq = True
|
||||
|
||||
if cfg.hqq and not cfg.quantization.hqq_config:
|
||||
raise ValueError(
|
||||
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
|
||||
)
|
||||
|
||||
|
||||
def normalize_cfg_datasets(cfg):
|
||||
"""
|
||||
|
||||
@@ -838,6 +838,7 @@ class ModelLoader:
|
||||
(not self.cfg.quantization)
|
||||
and (not self.cfg.load_in_8bit)
|
||||
and (not self.cfg.load_in_4bit)
|
||||
and not self.cfg.gptq
|
||||
):
|
||||
return
|
||||
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit
|
||||
|
||||
@@ -87,24 +87,24 @@ class LoraConfig(BaseModel):
|
||||
def validate_qlora(self):
|
||||
if self.adapter == "qlora":
|
||||
if self.merge_lora:
|
||||
if self.quantization.bits == 8:
|
||||
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't merge qlora if using gptq")
|
||||
|
||||
if self.quantization.bits == 4:
|
||||
if self.quantization.bits == 4 or self.cfg.load_in_4bit:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if self.quantization:
|
||||
if self.quantization.bits == 8:
|
||||
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't load qlora if using gptq")
|
||||
|
||||
if not self.quantization.bits == 4:
|
||||
if not self.quantization.bits == 4 or self.cfg.load_in_4bit:
|
||||
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||
|
||||
return self
|
||||
@@ -123,6 +123,24 @@ class LoraConfig(BaseModel):
|
||||
data["lora_dropout"] = 0.0
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
    """Check HQQ-specific constraints on the raw config before field validation.

    Only applies when ``data["quantization"]["backend"]`` equals ``"hqq"``;
    any other input is returned unchanged.

    Args:
        data: Raw config values. Presumably a dict-like mapping whose
            ``quantization`` entry is itself dict-like (both are accessed
            with ``.get``) -- verify against callers.

    Returns:
        ``data``, unmodified.

    Raises:
        ValueError: If the HQQ backend is selected without a truthy
            ``hqq_config`` under ``quantization``, or together with a
            truthy ``load_in_4bit`` / ``load_in_8bit``.
    """
    quant_section = data.get("quantization")

    # Nothing to validate unless the HQQ backend was explicitly requested.
    if not quant_section or quant_section.get("backend") != "hqq":
        return data

    if not quant_section.get("hqq_config"):
        raise ValueError(
            "If using HQQ, must set `hqq_config` under `quantization`"
        )

    # NOTE(review): assumes this check is scoped to the HQQ backend, as its
    # message implies -- confirm against the original indentation.
    uses_load_in_flag = data.get("load_in_4bit") or data.get("load_in_8bit")
    if uses_load_in_flag:
        raise ValueError(
            "If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
        )

    return data
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
},
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"gptq": True,
|
||||
|
||||
Reference in New Issue
Block a user