Fix: lora kernel pre-patch applied despite post-patch not applied (#2772)

* fix: do not pre-patch self attention if lora dropout non-zero

* fix: add test to check patch not applied

* fix: test

* fix: test config check

* fix where we check so that tests don't break

* fix: test

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
NanoCode012
2025-06-14 11:54:06 -07:00
committed by GitHub
parent 80d5b066ec
commit 21388cf615
3 changed files with 97 additions and 11 deletions

View File

@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
from axolotl.loaders.tokenizer import load_tokenizer
from axolotl.monkeypatch.lora_kernels import (
apply_lora_kernel_patches,
find_self_attn_in_layer,
get_attention_cls_from_config,
get_layers,
patch_self_attn_lora,
)
from axolotl.utils.dict import DictDefault
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
break
assert found_patched_attn
def test_kernel_training_integration_dropout_non_zero():
    """Check that LoRA kernel patches are skipped when ``lora_dropout`` is non-zero.

    Builds a minimal LoRA config with ``lora_dropout`` set to 0.1, loads the
    model, then invokes both patch hooks on the loader's patch manager and
    asserts that neither left a trace: the attention class keeps its original
    ``forward`` and no self-attention module has ``apply_qkv`` / ``apply_o``.
    """
    from axolotl.cli.utils import load_model_and_tokenizer

    model_id = "HuggingFaceTB/SmolLM2-135M"
    dataset_entry = {
        "path": "mhenrichsen/alpaca_2k_test",
        "type": "alpaca",
    }

    # Minimal config; the non-zero lora_dropout is the setting under test.
    config = DictDefault(
        {
            "base_model": model_id,
            "tokenizer_config": model_id,
            "learning_rate": 0.000001,
            "datasets": [dataset_entry],
            "micro_batch_size": 1,
            "gradient_accumulation_steps": 1,
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.1,
            "lora_target_linear": True,
            "sequence_len": 1024,
        }
    )

    # Snapshot the attention class's forward before anything gets a chance
    # to patch it, so it can be compared afterwards.
    attn_cls = get_attention_cls_from_config(config)
    forward_before = attn_cls.forward

    model, tokenizer, _ = load_model_and_tokenizer(cfg=config)

    # The patches live on ModelLoader's patch manager, so a loader is built
    # purely to reach them; it is not used to load the model here.
    patch_manager = ModelLoader(config, tokenizer).patch_manager

    # Pre-patch step: the attention class's forward must be left untouched.
    patch_manager._apply_self_attention_lora_patch()  # pylint: disable=protected-access
    assert attn_cls.forward == forward_before

    # Post-patch step: no self-attention module may gain the kernel hooks.
    patch_manager._apply_lora_kernel_patch(model)  # pylint: disable=protected-access
    for layer in get_layers(model):
        for attn_module in find_self_attn_in_layer(layer):
            for hook_name in ("apply_qkv", "apply_o"):
                assert not hasattr(attn_module, hook_name)