diff --git a/src/axolotl/utils/schemas/peft.py b/src/axolotl/utils/schemas/peft.py
index 4cf8dafe7..3cedec8b2 100644
--- a/src/axolotl/utils/schemas/peft.py
+++ b/src/axolotl/utils/schemas/peft.py
@@ -78,7 +78,7 @@ class LoraConfig(BaseModel):
             and (data.get("quantization"))
         ):
             raise ValueError(
-                "Quantization is not supported without setting an adapter for training."
+                "Quantization is not supported without setting an adapter. "
                 "If you want to full finetune, please turn off Quantization."
             )
         return data
@@ -87,21 +87,26 @@ class LoraConfig(BaseModel):
     def validate_qlora(self):
         if self.adapter == "qlora":
             if self.merge_lora:
-                # can't merge qlora if loaded in 8bit or 4bit
-                if self.quantization:
-                    raise ValueError("Can't merge qlora if loaded in quantized model")
+                if self.quantization.bits == 8:
+                    raise ValueError("Can't merge qlora if loaded in 8bit")
 
                 if self.quantization.backend == "gptq":
                     raise ValueError("Can't merge qlora if using gptq")
 
+                if self.quantization.bits == 4:
+                    raise ValueError("Can't merge qlora if loaded in 4bit")
+
             else:
                 if self.quantization:
-                    if self.quantization.bits > 4:
-                        raise ValueError("Can't load qlora in >4 bit")
+                    if self.quantization.bits == 8:
+                        raise ValueError("Can't load qlora in 8bit")
 
                     if self.quantization.backend == "gptq":
                         raise ValueError("Can't load qlora if using gptq")
 
+                    if not self.quantization.bits == 4:
+                        raise ValueError("Require quantization.bits <= 4 for qlora")
+
         return self
 
     @field_validator("loraplus_lr_embedding")
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 3262a6981..dc451b870 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -74,7 +74,11 @@ class TestValidation(BaseValidation):
                 "deepspeed": "deepspeed_configs/zero3_bf16.json",
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
+                # "load_in_4bit": True
                 "adapter": "qlora",
             } | minimal_cfg
 
@@ -93,7 +97,10 @@ class TestValidation(BaseValidation):
                 "deepspeed": "",
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
             } | minimal_cfg
 
@@ -107,7 +114,10 @@ class TestValidation(BaseValidation):
                 "deepspeed": None,
                 "gradient_checkpointing": True,
                 "gradient_checkpointing_kwargs": {"use_reentrant": False},
-                "load_in_4bit": True,
+                "quantization": {
+                    "backend": "bnb",
+                    "bits": 4,
+                },
                 "adapter": "qlora",
             } | minimal_cfg
 
@@ -306,7 +316,10 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 8,
+                    },
                 }
             )
             | base_cfg
@@ -318,7 +331,9 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "gptq": True,
+                    "quantization": {
+                        "backend": "gptq",
+                    },
                 }
             )
             | base_cfg
@@ -330,19 +345,24 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": False,
+                    "quantization": {
+                        "bits": None,
+                    },
                 }
             )
             | base_cfg
         )
 
-        with pytest.raises(ValueError, match=r".*4bit.*"):
+        with pytest.raises(ValueError, match=r".*bits <= 4.*"):
             validate_config(cfg)
 
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -364,7 +384,10 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "backend": "bnb",
+                        "bits": 8,
+                    },
                 }
             )
             | base_cfg
@@ -376,7 +399,10 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "gptq": True,
+                    "quantization": {
+                        "backend": "gptq",
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -388,7 +414,9 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(  # pylint: disable=unsupported-binary-operation
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": 4,
+                    },
                 }
             )
             | base_cfg
@@ -976,7 +1004,9 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": None,
+                    },
                 }
             )
             | minimal_cfg
@@ -984,29 +1014,16 @@ class TestValidation(BaseValidation):
 
         with pytest.raises(
             ValueError,
-            match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
+            match=r"Quantization is not supported without setting an adapter.*",
         ):
             validate_config(cfg)
 
         cfg = (
             DictDefault(
                 {
-                    "load_in_8bit": True,
-                }
-            )
-            | minimal_cfg
-        )
-
-        with pytest.raises(
-            ValueError,
-            match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
-        ):
-            validate_config(cfg)
-
-        cfg = (
-            DictDefault(
-                {
-                    "load_in_4bit": True,
+                    "quantization": {
+                        "bits": 4,
+                    },
                     "adapter": "qlora",
                 }
             )
@@ -1018,7 +1035,9 @@ class TestValidation(BaseValidation):
         cfg = (
             DictDefault(
                 {
-                    "load_in_8bit": True,
+                    "quantization": {
+                        "bits": 8,
+                    },
                     "adapter": "lora",
                 }
             )