amend unittests pt2
This commit is contained in:
@@ -78,7 +78,7 @@ class LoraConfig(BaseModel):
|
||||
and (data.get("quantization"))
|
||||
):
|
||||
raise ValueError(
|
||||
"Quantization is not supported without setting an adapter for training."
|
||||
"Quantization is not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off Quantization."
|
||||
)
|
||||
return data
|
||||
@@ -87,21 +87,26 @@ class LoraConfig(BaseModel):
|
||||
def validate_qlora(self):
|
||||
if self.adapter == "qlora":
|
||||
if self.merge_lora:
|
||||
# can't merge qlora if loaded in 8bit or 4bit
|
||||
if self.quantization:
|
||||
raise ValueError("Can't merge qlora if loaded in quantized model")
|
||||
if self.quantization.bits == 8:
|
||||
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't merge qlora if using gptq")
|
||||
|
||||
if self.quantization.bits == 4:
|
||||
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||
|
||||
else:
|
||||
if self.quantization:
|
||||
if self.quantization.bits > 4:
|
||||
raise ValueError("Can't load qlora in >4 bit")
|
||||
if self.quantization.bits == 8:
|
||||
raise ValueError("Can't load qlora in 8bit")
|
||||
|
||||
if self.quantization.backend == "gptq":
|
||||
raise ValueError("Can't load qlora if using gptq")
|
||||
|
||||
if not self.quantization.bits == 4:
|
||||
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("loraplus_lr_embedding")
|
||||
|
||||
Reference in New Issue
Block a user