Fix: LoRA kernel pre-patch applied despite post-patch not being applied (#2772)

* fix: do not pre-patch self-attention if lora_dropout is non-zero

* fix: add test to check the patch is not applied

* fix: test

* fix: test config check

* fix: relocate the check so that tests don't break

* fix: test

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
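
For context, the mismatch only arises when a fused LoRA kernel is requested while dropout is enabled. Below is a minimal sketch of such a config fragment (illustrative values only, not a complete training config; the option names are the same ones checked in the diff below):

    # Illustrative fragment: lora_qkv_kernel / lora_o_kernel request the fused
    # kernel pre-patch, but a non-zero lora_dropout means the layer-level
    # post-patch is skipped. Previously the self-attention forward was still
    # pre-patched even though the kernels were never attached.
    from axolotl.utils.dict import DictDefault

    cfg = DictDefault(
        {
            "adapter": "lora",
            "lora_r": 8,
            "lora_alpha": 16,
            "lora_dropout": 0.1,  # non-zero dropout: kernels unsupported
            "lora_qkv_kernel": True,
            "lora_o_kernel": True,
        }
    )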
Commit 21388cf615 (parent 80d5b066ec)
Author: NanoCode012
Committed by: GitHub
Date: 2025-06-14 11:54:06 -07:00
3 changed files with 97 additions and 11 deletions

Changed file 1 of 3

@@ -166,6 +166,17 @@ class PatchManager:
     def _apply_self_attention_lora_patch(self):
         """Apply self-attention LoRA patches if configured."""
         if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
+            # Only patch if conditions are met
+            can_patch = (
+                self.cfg.lora_dropout == 0
+                if hasattr(self.cfg, "lora_dropout")
+                else True
+            )  # default to True if lora_dropout is not set
+            if not can_patch:
+                LOG.warning("Cannot patch self-attention - requires no dropout")
+                return
+
             from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
 
             patch_self_attn_lora(self.cfg)
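
Taken on its own, the new guard boils down to a small predicate. A minimal standalone sketch of the same logic (the helper name is hypothetical and not part of this commit):

    # Hypothetical helper mirroring the guard above: allow the self-attention
    # LoRA pre-patch only when a kernel is requested and dropout is zero
    # (or unset, which is treated as zero).
    def can_apply_self_attn_lora_patch(cfg) -> bool:
        wants_kernel = bool(cfg.get("lora_qkv_kernel") or cfg.get("lora_o_kernel"))
        dropout_is_zero = not cfg.get("lora_dropout", 0)
        return wants_kernel and dropout_is_zero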

Changed file 2 of 3

@@ -274,6 +274,29 @@ def find_mlp_in_layer(
 )
 
 
+def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
+    """
+    Get the layers of the model. Handles text-only and multimodal models.
+
+    Args:
+        model: A PEFT model.
+
+    Returns:
+        A list of layers.
+    """
+    pretrained_model = model.model
+
+    # check for multimodal models first
+    if hasattr(pretrained_model, "language_model"):
+        return pretrained_model.language_model.layers
+    if hasattr(pretrained_model, "model"):
+        return pretrained_model.model.layers
+
+    raise NotImplementedError(
+        f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
+    )
+
+
 def apply_lora_kernel_patches(
     model: PeftModelForCausalLM, cfg: DictDefault
 ) -> PeftModelForCausalLM:
@@ -345,17 +368,7 @@ def apply_lora_kernel_patches(
     if activation not in SUPPORTED_ACTIVATIONS:
         raise NotImplementedError(f"Activation {activation} is not supported")
 
-    layers = []
-    # check for multimodal models first
-    pretrained_model = model.model
-    if hasattr(pretrained_model, "language_model"):
-        layers = pretrained_model.language_model.layers
-    elif hasattr(pretrained_model, "model"):
-        layers = pretrained_model.model.layers
-    else:
-        raise NotImplementedError(
-            f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
-        )
+    layers = get_layers(model)
 
     # Patch each layer
     for layer in layers:
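
With layer discovery factored out into get_layers, it is also straightforward to check from the outside whether the kernels actually landed. A minimal usage sketch, assuming a PEFT model that has already been passed through apply_lora_kernel_patches (find_self_attn_in_layer is the existing helper also imported by the test below):

    # Sketch: count how many self-attention modules received the fused
    # LoRA kernel hooks (apply_qkv / apply_o) after patching.
    from axolotl.monkeypatch.lora_kernels import find_self_attn_in_layer, get_layers

    def count_patched_self_attn(model) -> int:
        patched = 0
        for layer in get_layers(model):
            for self_attn in find_self_attn_in_layer(layer):
                if hasattr(self_attn, "apply_qkv") or hasattr(self_attn, "apply_o"):
                    patched += 1
        return patched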

Changed file 3 of 3

@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
 from axolotl.loaders.tokenizer import load_tokenizer
 from axolotl.monkeypatch.lora_kernels import (
     apply_lora_kernel_patches,
+    find_self_attn_in_layer,
     get_attention_cls_from_config,
+    get_layers,
     patch_self_attn_lora,
 )
 from axolotl.utils.dict import DictDefault
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
             break
 
     assert found_patched_attn
+
+
+def test_kernel_training_integration_dropout_non_zero():
+    """Test model loading with dropout non-zero should not patch."""
+    from axolotl.cli.utils import load_model_and_tokenizer
+
+    # Create minimal config
+    cfg = DictDefault(
+        {
+            "base_model": "HuggingFaceTB/SmolLM2-135M",
+            "tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
+            "learning_rate": 0.000001,
+            "datasets": [
+                {
+                    "path": "mhenrichsen/alpaca_2k_test",
+                    "type": "alpaca",
+                }
+            ],
+            "micro_batch_size": 1,
+            "gradient_accumulation_steps": 1,
+            "adapter": "lora",
+            "lora_r": 8,
+            "lora_alpha": 16,
+            "lora_dropout": 0.1,
+            "lora_target_linear": True,
+            "sequence_len": 1024,
+        }
+    )
+
+    # Get original attention class
+    attention_cls = get_attention_cls_from_config(cfg)
+
+    # Store original state before patching
+    original_forward_method = attention_cls.forward
+
+    # Load model
+    model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
+
+    # We call ModelLoader as that's where the patches are applied,
+    # despite the fact that we're not using it to load the model
+    model_loader = ModelLoader(cfg, tokenizer)
+
+    # Apply patch
+    model_loader.patch_manager._apply_self_attention_lora_patch()  # pylint: disable=protected-access
+
+    # Verify patch was not applied
+    assert attention_cls.forward == original_forward_method
+
+    # Apply apply_lora_kernel_patches
+    model_loader.patch_manager._apply_lora_kernel_patch(  # pylint: disable=protected-access
+        model
+    )
+
+    # Verify patch was not applied
+    layers = get_layers(model)
+    for layer in layers:
+        for self_attn in find_self_attn_in_layer(layer):
+            assert not hasattr(self_attn, "apply_qkv")
+            assert not hasattr(self_attn, "apply_o")
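
Conversely, the configuration that lets both stages run only differs in the dropout value. A minimal illustrative fragment (not a complete config; values are examples):

    # Illustrative fragment: with lora_dropout at zero, the self-attention
    # pre-patch is allowed and apply_lora_kernel_patches can attach the
    # fused kernels, so the two stages stay consistent.
    from axolotl.utils.dict import DictDefault

    kernel_ok_cfg = DictDefault(
        {
            "adapter": "lora",
            "lora_dropout": 0.0,  # zero dropout: kernel patches may be applied
            "lora_qkv_kernel": True,
            "lora_o_kernel": True,
        }
    )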