assume empty lora dropout means 0.0 and add tests (#2243)
* assume empty lora dropout means 0.0 and add tests * remove un-necessary arg * refactor based on pr feedback: * chore: lint
This commit is contained in:
@@ -367,6 +367,13 @@ class LoraConfig(BaseModel):
|
||||
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
||||
return loraplus_lr_embedding
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_lora_dropout(cls, data):
|
||||
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
||||
data["lora_dropout"] = 0.0
|
||||
return data
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
Reference in New Issue
Block a user