amend unittests pt2

This commit is contained in:
Sunny Liu
2025-04-21 13:28:52 -04:00
parent 9be971d47c
commit c8fb5baad6
2 changed files with 60 additions and 36 deletions

View File

@@ -78,7 +78,7 @@ class LoraConfig(BaseModel):
and (data.get("quantization")) and (data.get("quantization"))
): ):
raise ValueError( raise ValueError(
"Quantization is not supported without setting an adapter for training." "Quantization is not supported without setting an adapter."
"If you want to full finetune, please turn off Quantization." "If you want to full finetune, please turn off Quantization."
) )
return data return data
@@ -87,21 +87,26 @@ class LoraConfig(BaseModel):
def validate_qlora(self): def validate_qlora(self):
if self.adapter == "qlora": if self.adapter == "qlora":
if self.merge_lora: if self.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit if self.quantization.bits == 8:
if self.quantization: raise ValueError("Can't merge qlora if loaded in 8bit")
raise ValueError("Can't merge qlora if loaded in quantized model")
if self.quantization.backend == "gptq": if self.quantization.backend == "gptq":
raise ValueError("Can't merge qlora if using gptq") raise ValueError("Can't merge qlora if using gptq")
if self.quantization.bits == 4:
raise ValueError("Can't merge qlora if loaded in 4bit")
else: else:
if self.quantization: if self.quantization:
if self.quantization.bits > 4: if self.quantization.bits == 8:
raise ValueError("Can't load qlora in >4 bit") raise ValueError("Can't load qlora in 8bit")
if self.quantization.backend == "gptq": if self.quantization.backend == "gptq":
raise ValueError("Can't load qlora if using gptq") raise ValueError("Can't load qlora if using gptq")
if not self.quantization.bits == 4:
raise ValueError("Require quantization.bits <= 4 for qlora")
return self return self
@field_validator("loraplus_lr_embedding") @field_validator("loraplus_lr_embedding")

View File

@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
"deepspeed": "deepspeed_configs/zero3_bf16.json", "deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True, "quantization": {
"backend": "bnb",
"bits": 4,
},
# "load_in_4bit": True
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
"deepspeed": "", "deepspeed": "",
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True, "quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
"deepspeed": None, "deepspeed": None,
"gradient_checkpointing": True, "gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False}, "gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True, "quantization": {
"backend": "bnb",
"bits": 4,
},
"adapter": "qlora", "adapter": "qlora",
} }
| minimal_cfg | minimal_cfg
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"load_in_8bit": True, "quantization": {
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"gptq": True, "quantization": {
"backend": "gptq",
},
} }
) )
| base_cfg | base_cfg
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"load_in_4bit": False, "quantization": {
"bits": None,
},
} }
) )
| base_cfg | base_cfg
) )
with pytest.raises(ValueError, match=r".*4bit.*"): with pytest.raises(ValueError, match=r".*bits <= 4*"):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"load_in_4bit": True, "quantization": {
"backend": "bnb",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"load_in_8bit": True, "quantization": {
"backend": "bnb",
"bits": 8,
},
} }
) )
| base_cfg | base_cfg
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"gptq": True, "quantization": {
"backend": "gptq",
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( # pylint: disable=unsupported-binary-operation DictDefault( # pylint: disable=unsupported-binary-operation
{ {
"load_in_4bit": True, "quantization": {
"bits": 4,
},
} }
) )
| base_cfg | base_cfg
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"load_in_4bit": True, "quantization": {
"bits": None,
},
} }
) )
| minimal_cfg | minimal_cfg
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*", match=r"Quantization is not supported without setting an adapter.*",
): ):
validate_config(cfg) validate_config(cfg)
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"load_in_8bit": True, "quantization": {
} "bits": 4,
) },
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = (
DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora", "adapter": "qlora",
} }
) )
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
cfg = ( cfg = (
DictDefault( DictDefault(
{ {
"load_in_8bit": True, "quantization": {
"bits": 8,
},
"adapter": "lora", "adapter": "lora",
} }
) )