amend unittests pt2
This commit is contained in:
@@ -78,7 +78,7 @@ class LoraConfig(BaseModel):
|
|||||||
and (data.get("quantization"))
|
and (data.get("quantization"))
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Quantization is not supported without setting an adapter for training."
|
"Quantization is not supported without setting an adapter."
|
||||||
"If you want to full finetune, please turn off Quantization."
|
"If you want to full finetune, please turn off Quantization."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
@@ -87,21 +87,26 @@ class LoraConfig(BaseModel):
|
|||||||
def validate_qlora(self):
|
def validate_qlora(self):
|
||||||
if self.adapter == "qlora":
|
if self.adapter == "qlora":
|
||||||
if self.merge_lora:
|
if self.merge_lora:
|
||||||
# can't merge qlora if loaded in 8bit or 4bit
|
if self.quantization.bits == 8:
|
||||||
if self.quantization:
|
raise ValueError("Can't merge qlora if loaded in 8bit")
|
||||||
raise ValueError("Can't merge qlora if loaded in quantized model")
|
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't merge qlora if using gptq")
|
raise ValueError("Can't merge qlora if using gptq")
|
||||||
|
|
||||||
|
if self.quantization.bits == 4:
|
||||||
|
raise ValueError("Can't merge qlora if loaded in 4bit")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
if self.quantization:
|
if self.quantization:
|
||||||
if self.quantization.bits > 4:
|
if self.quantization.bits == 8:
|
||||||
raise ValueError("Can't load qlora in >4 bit")
|
raise ValueError("Can't load qlora in 8bit")
|
||||||
|
|
||||||
if self.quantization.backend == "gptq":
|
if self.quantization.backend == "gptq":
|
||||||
raise ValueError("Can't load qlora if using gptq")
|
raise ValueError("Can't load qlora if using gptq")
|
||||||
|
|
||||||
|
if not self.quantization.bits == 4:
|
||||||
|
raise ValueError("Require quantization.bits <= 4 for qlora")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@field_validator("loraplus_lr_embedding")
|
@field_validator("loraplus_lr_embedding")
|
||||||
|
|||||||
@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
|
# "load_in_4bit": True
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": "",
|
"deepspeed": "",
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
|
|||||||
"deepspeed": None,
|
"deepspeed": None,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"gptq": True,
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": False,
|
"quantization": {
|
||||||
|
"bits": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
with pytest.raises(ValueError, match=r".*bits <= 4*"):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"backend": "bnb",
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"gptq": True,
|
"quantization": {
|
||||||
|
"backend": "gptq",
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"bits": 4,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| base_cfg
|
| base_cfg
|
||||||
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_4bit": True,
|
"quantization": {
|
||||||
|
"bits": None,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
| minimal_cfg
|
| minimal_cfg
|
||||||
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
match=r"Quantization is not supported without setting an adapter.*",
|
||||||
):
|
):
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
}
|
"bits": 4,
|
||||||
)
|
},
|
||||||
| minimal_cfg
|
|
||||||
)
|
|
||||||
|
|
||||||
with pytest.raises(
|
|
||||||
ValueError,
|
|
||||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
|
||||||
):
|
|
||||||
validate_config(cfg)
|
|
||||||
|
|
||||||
cfg = (
|
|
||||||
DictDefault(
|
|
||||||
{
|
|
||||||
"load_in_4bit": True,
|
|
||||||
"adapter": "qlora",
|
"adapter": "qlora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
|
|||||||
cfg = (
|
cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
{
|
{
|
||||||
"load_in_8bit": True,
|
"quantization": {
|
||||||
|
"bits": 8,
|
||||||
|
},
|
||||||
"adapter": "lora",
|
"adapter": "lora",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user