diff --git a/src/axolotl/utils/schemas/quant.py b/src/axolotl/utils/schemas/quant.py index 0b9d2dd0f..4714d0799 100644 --- a/src/axolotl/utils/schemas/quant.py +++ b/src/axolotl/utils/schemas/quant.py @@ -11,7 +11,7 @@ from pydantic import BaseModel, Field, model_validator class HQQConfig(BaseModel): """HQQ configuration subset""" - n_bits: int | None = Field(default=None) + nbits: int | None = Field(default=None) group_size: int | None = Field(default=None) target_modules: list[str] | str | None = Field( default=None, @@ -51,7 +51,7 @@ def get_hqq_quant_config_kwargs(cfg): # If no target module is specified, then target the whole model if len(cfg.hqq_config) == 1 and cfg.hqq_config[0].target_modules is None: return { - "nbits": cfg.hqq_config[0].n_bits, + "nbits": cfg.hqq_config[0].nbits, "group_size": cfg.hqq_config[0].group_size, } @@ -63,8 +63,8 @@ def get_hqq_quant_config_kwargs(cfg): for module in target_modules: hqq_quant_config_kwargs["dynamic_config"][module] = { - "nbits": hqq_config.hqq_nbits, - "group_size": hqq_config.hqq_group_size, + "nbits": hqq_config.nbits, + "group_size": hqq_config.group_size, } return hqq_quant_config_kwargs diff --git a/tests/e2e/test_quantization.py b/tests/e2e/test_quantization.py index f067df8f5..f6dab452a 100644 --- a/tests/e2e/test_quantization.py +++ b/tests/e2e/test_quantization.py @@ -36,11 +36,11 @@ class TestHQQ(unittest.TestCase): "use_hqq": True, "hqq_config": [ { - "nbits": 4, + "nbits": 8, "group_size": 32, } ], - "lora_adapter": "qlora", + "adapter": "lora", "lora_r": 16, "lora_alpha": 32, "lora_dropout": 0.05,