amend unittests pt2

This commit is contained in:
Sunny Liu
2025-04-21 13:28:52 -04:00
parent 9be971d47c
commit c8fb5baad6
2 changed files with 60 additions and 36 deletions

View File

@@ -78,7 +78,7 @@ class LoraConfig(BaseModel):
and (data.get("quantization"))
):
raise ValueError(
"Quantization is not supported without setting an adapter for training."
"Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization."
)
return data
@@ -87,21 +87,26 @@ class LoraConfig(BaseModel):
def validate_qlora(self):
if self.adapter == "qlora":
if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit
if self.quantization:
raise ValueError("Can't merge qlora if loaded in quantized model")
if self.quantization.bits == 8:
raise ValueError("Can't merge qlora if loaded in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq")
if self.quantization.bits == 4:
raise ValueError("Can't merge qlora if loaded in 4bit")
else:
if self.quantization:
if self.quantization.bits > 4:
raise ValueError("Can't load qlora in >4 bit")
if self.quantization.bits == 8:
raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4:
raise ValueError("Require quantization.bits <= 4 for qlora")
return self
@field_validator("loraplus_lr_embedding")