From 970b2a6f2fcd5c364d5e8a9a6bf63ed3671569b4 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 16 Feb 2026 21:26:04 +0700 Subject: [PATCH] feat: test for config validation and BC for new peft weight dtype --- .../utils/lora/test_config_validation_lora.py | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) diff --git a/tests/utils/lora/test_config_validation_lora.py b/tests/utils/lora/test_config_validation_lora.py index 9d97288b6..0f98e38c2 100644 --- a/tests/utils/lora/test_config_validation_lora.py +++ b/tests/utils/lora/test_config_validation_lora.py @@ -3,6 +3,14 @@ import pytest from axolotl.utils.config import validate_config from axolotl.utils.dict import DictDefault +BASE_CFG = { + "datasets": [{"path": "dummy_dataset", "type": "alpaca"}], + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "learning_rate": 1e-5, + "base_model": "dummy_model", +} + class TestLoRAConfigValidation: """Test suite for LoRA/QLoRA configuration validation""" @@ -149,3 +157,195 @@ class TestLoRAConfigValidation: result = validate_config(valid_config) assert result["lora_qkv_kernel"] is True assert result["trust_remote_code"] is None + + +class TestTorchaoQLoRAConfigValidation: + """Test suite for torchao QLoRA auto-detection and validation""" + + # --- Auto-detection: torchao --- + + @pytest.mark.parametrize("weight_dtype", ["int4", "int8", "nf4"]) + def test_torchao_auto_detect_from_lora(self, weight_dtype): + """adapter: lora + peft.backend: torchao auto-upgrades to qlora""" + cfg = DictDefault( + { + "adapter": "lora", + "peft": {"backend": "torchao", "weight_dtype": weight_dtype}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + assert result["peft"]["backend"] == "torchao" + + def test_torchao_explicit_qlora(self): + """adapter: qlora + peft.backend: torchao works directly""" + cfg = DictDefault( + { + "adapter": "qlora", + "peft": {"backend": "torchao", "weight_dtype": "int4"}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + + # --- Auto-detection: bnb --- + + def test_bnb_nf4_auto_detect_from_lora(self): + """adapter: lora + peft.backend: bnb + weight_dtype: nf4 → qlora + load_in_4bit""" + cfg = DictDefault( + { + "adapter": "lora", + "peft": {"backend": "bnb", "weight_dtype": "nf4"}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + assert result["load_in_4bit"] is True + + def test_bnb_int8_auto_detect_from_lora(self): + """adapter: lora + peft.backend: bnb + weight_dtype: int8 → lora + load_in_8bit""" + cfg = DictDefault( + { + "adapter": "lora", + "peft": {"backend": "bnb", "weight_dtype": "int8"}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "lora" + assert result["load_in_8bit"] is True + + def test_bnb_nf4_explicit_qlora_auto_sets_load_in_4bit(self): + """adapter: qlora + peft.backend: bnb + weight_dtype: nf4 auto-sets load_in_4bit""" + cfg = DictDefault( + { + "adapter": "qlora", + "peft": {"backend": "bnb", "weight_dtype": "nf4"}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + assert result["load_in_4bit"] is True + + # --- Backward compat --- + + def test_old_style_qlora_unchanged(self): + """Old-style adapter: qlora + load_in_4bit: true still works""" + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + assert result["load_in_4bit"] is True + + def test_old_style_lora_8bit_unchanged(self): + """Old-style adapter: lora + load_in_8bit: true still works""" + cfg = DictDefault( + { + "adapter": "lora", + "load_in_8bit": True, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "lora" + assert result["load_in_8bit"] is True + + def test_plain_lora_unchanged(self): + """adapter: lora without peft block stays as lora""" + cfg = DictDefault( + { + "adapter": "lora", + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "lora" + + # --- Validation errors --- + + def test_torchao_with_load_in_4bit_errors(self): + """peft.backend: torchao + load_in_4bit is a conflict""" + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_4bit": True, + "peft": {"backend": "torchao", "weight_dtype": "int4"}, + **BASE_CFG, + } + ) + with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"): + validate_config(cfg) + + def test_torchao_with_load_in_8bit_errors(self): + """peft.backend: torchao + load_in_8bit is a conflict""" + cfg = DictDefault( + { + "adapter": "qlora", + "load_in_8bit": True, + "peft": {"backend": "torchao", "weight_dtype": "int4"}, + **BASE_CFG, + } + ) + with pytest.raises(ValueError, match="load_in_4bit.*bitsandbytes"): + validate_config(cfg) + + def test_torchao_without_weight_dtype_errors(self): + """peft.backend: torchao without weight_dtype errors""" + cfg = DictDefault( + { + "adapter": "qlora", + "peft": {"backend": "torchao"}, + **BASE_CFG, + } + ) + with pytest.raises(ValueError, match="peft.weight_dtype is required"): + validate_config(cfg) + + def test_weight_dtype_without_backend_errors(self): + """peft.weight_dtype without peft.backend errors""" + cfg = DictDefault( + { + "adapter": "lora", + "peft": {"weight_dtype": "int4"}, + **BASE_CFG, + } + ) + with pytest.raises(ValueError, match="peft.backend is required"): + validate_config(cfg) + + def test_bnb_unsupported_weight_dtype_errors(self): + """peft.backend: bnb + unsupported weight_dtype errors""" + cfg = DictDefault( + { + "adapter": "lora", + "peft": {"backend": "bnb", "weight_dtype": "int4"}, + **BASE_CFG, + } + ) + with pytest.raises(ValueError, match="not supported with bnb"): + validate_config(cfg) + + # --- Redundant flags don't conflict --- + + def test_bnb_nf4_with_explicit_load_in_4bit(self): + """peft.backend: bnb + weight_dtype: nf4 + load_in_4bit: true is fine (redundant)""" + cfg = DictDefault( + { + "adapter": "lora", + "load_in_4bit": True, + "peft": {"backend": "bnb", "weight_dtype": "nf4"}, + **BASE_CFG, + } + ) + result = validate_config(cfg) + assert result["adapter"] == "qlora" + assert result["load_in_4bit"] is True