Update more tests and improve HQQ validation

This commit is contained in:
Sunny Liu
2025-04-21 22:17:08 -04:00
parent db7e92f6a6
commit c4910da015
4 changed files with 27 additions and 10 deletions

View File

@@ -243,16 +243,11 @@ def normalize_config(cfg):
elif cfg.quantization.bits == 4:
cfg.load_in_4bit = True
elif cfg.quantization.backend == "gptq":
if cfg.quantization.backend == "gptq":
cfg.gptq = True
elif cfg.quantization.backend == "hqq":
cfg.hqq = True
if cfg.hqq and not cfg.quantization.hqq_config:
raise ValueError(
"If using HQQ, must set `hqq_config` to a list of HQQConfig objects"
)
def normalize_cfg_datasets(cfg):
"""

View File

@@ -838,6 +838,7 @@ class ModelLoader:
(not self.cfg.quantization)
and (not self.cfg.load_in_8bit)
and (not self.cfg.load_in_4bit)
and not self.cfg.gptq
):
return
self.model_kwargs["load_in_8bit"] = self.cfg.load_in_8bit

View File

@@ -87,24 +87,24 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
if self.quantization.bits == 8:
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.quantization.bits == 4:
if self.quantization.bits == 4 or self.cfg.load_in_4bit:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.quantization:
if self.quantization.bits == 8:
if self.quantization.bits == 8 or self.cfg.load_in_8bit:
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4:
if not self.quantization.bits == 4 or self.cfg.load_in_4bit:
raise ValueError("Require quantization.bits <= 4 for qlora")
return self
@@ -123,6 +123,24 @@ class LoraConfig(BaseModel):
data["lora_dropout"] = 0.0
return data
@model_validator(mode="before")
@classmethod
def validate_hqq(cls, data):
    """Reject inconsistent HQQ quantization settings early.

    Runs as a pydantic ``mode="before"`` validator, so ``data`` is the raw
    input mapping rather than a constructed model instance.

    Raises:
        ValueError: if ``quantization.backend`` is ``"hqq"`` but no
            ``quantization.hqq_config`` was provided, or if HQQ is combined
            with the ``load_in_4bit`` / ``load_in_8bit`` flags.
    """
    if (
        data.get("quantization")
        # NOTE(review): assumes data["quantization"] is a mapping here —
        # a non-dict truthy value would make the chained .get() raise
        # AttributeError. Confirm upstream normalization guarantees a dict.
        and data.get("quantization").get("backend") == "hqq"
    ):
        # HQQ requires an explicit config; there is no implicit default.
        if not data.get("quantization").get("hqq_config"):
            raise ValueError(
                "If using HQQ, must set `hqq_config` under `quantization`"
            )
        # The load_in_*bit flags select a different quantization path and
        # must not be set together with the HQQ backend.
        if data.get("load_in_4bit") or data.get("load_in_8bit"):
            raise ValueError(
                "If using HQQ quantization, please remove load_in_4bit or load_in_8bit"
            )
    return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""

View File

@@ -89,6 +89,9 @@ class TestLoraLlama(unittest.TestCase):
"sequence_len": 1024,
"sample_packing": True,
"flash_attention": True,
"quantization": {
"backend": "gptq",
},
"load_in_8bit": True,
"adapter": "lora",
"gptq": True,