Fix: lora kernel pre-patch applied despite post-patch not applied (#2772)
* fix: do not pre-patch self attention if lora dropout non-zero * fix: add test to check patch not applied * fix: test * fix: test config check * fix where we check so that tests don't break * fix: test --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -166,6 +166,17 @@ class PatchManager:
|
||||
def _apply_self_attention_lora_patch(self):
|
||||
"""Apply self-attention LoRA patches if configured."""
|
||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||
# Only patch if conditions are met
|
||||
can_patch = (
|
||||
self.cfg.lora_dropout == 0
|
||||
if hasattr(self.cfg, "lora_dropout")
|
||||
else True
|
||||
) # default to True if lora_dropout is not set
|
||||
|
||||
if not can_patch:
|
||||
LOG.warning("Cannot patch self-attention - requires no dropout")
|
||||
return
|
||||
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
@@ -274,6 +274,29 @@ def find_mlp_in_layer(
|
||||
)
|
||||
|
||||
|
||||
def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||
"""
|
||||
Get the layers of the model. Handles text-only and multimodal models.
|
||||
|
||||
Args:
|
||||
model: A PEFT model.
|
||||
|
||||
Returns:
|
||||
A list of layers.
|
||||
"""
|
||||
pretrained_model = model.model
|
||||
|
||||
# check for multimodal models first
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
return pretrained_model.language_model.layers
|
||||
if hasattr(pretrained_model, "model"):
|
||||
return pretrained_model.model.layers
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
model: PeftModelForCausalLM, cfg: DictDefault
|
||||
) -> PeftModelForCausalLM:
|
||||
@@ -345,17 +368,7 @@ def apply_lora_kernel_patches(
|
||||
if activation not in SUPPORTED_ACTIVATIONS:
|
||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||
|
||||
layers = []
|
||||
# check for multimodal models first
|
||||
pretrained_model = model.model
|
||||
if hasattr(pretrained_model, "language_model"):
|
||||
layers = pretrained_model.language_model.layers
|
||||
elif hasattr(pretrained_model, "model"):
|
||||
layers = pretrained_model.model.layers
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||
)
|
||||
layers = get_layers(model)
|
||||
|
||||
# Patch each layer
|
||||
for layer in layers:
|
||||
|
||||
@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
find_self_attn_in_layer,
|
||||
get_attention_cls_from_config,
|
||||
get_layers,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
break
|
||||
|
||||
assert found_patched_attn
|
||||
|
||||
|
||||
def test_kernel_training_integration_dropout_non_zero():
|
||||
"""Test model loading with dropout non-zero should not patch."""
|
||||
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
|
||||
# Create minimal config
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.1,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
}
|
||||
)
|
||||
|
||||
# Get original attention class
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load model
|
||||
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||
|
||||
# We call modelloader as that's where the patches are applied
|
||||
# despite the fact that we're not using it to load the model
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
|
||||
model
|
||||
)
|
||||
|
||||
# Verify patch was not applied
|
||||
layers = get_layers(model)
|
||||
for layer in layers:
|
||||
for self_attn in find_self_attn_in_layer(layer):
|
||||
assert not hasattr(self_attn, "apply_qkv")
|
||||
assert not hasattr(self_attn, "apply_o")
|
||||
|
||||
Reference in New Issue
Block a user