amend unittests pt2
This commit is contained in:
@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": "deepspeed_configs/zero3_bf16.json",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
# "load_in_4bit": True
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": "",
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
|
||||
"deepspeed": None,
|
||||
"gradient_checkpointing": True,
|
||||
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
| minimal_cfg
|
||||
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
"quantization": {
|
||||
"bits": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=r".*4bit.*"):
|
||||
with pytest.raises(ValueError, match=r".*bits <= 4*"):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"backend": "bnb",
|
||||
"bits": 8,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"gptq": True,
|
||||
"quantization": {
|
||||
"backend": "gptq",
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": 4,
|
||||
},
|
||||
}
|
||||
)
|
||||
| base_cfg
|
||||
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": None,
|
||||
},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
match=r"Quantization is not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"quantization": {
|
||||
"bits": 4,
|
||||
},
|
||||
"adapter": "qlora",
|
||||
}
|
||||
)
|
||||
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"quantization": {
|
||||
"bits": 8,
|
||||
},
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user