Fix: lora kernel pre-patch applied despite post-patch not applied (#2772)
* fix: do not pre-patch self attention if lora dropout non-zero * fix: add test to check patch not applied * fix: test * fix: test config check * fix where we check so that tests don't break * fix: test --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
@@ -166,6 +166,17 @@ class PatchManager:
|
|||||||
def _apply_self_attention_lora_patch(self):
|
def _apply_self_attention_lora_patch(self):
|
||||||
"""Apply self-attention LoRA patches if configured."""
|
"""Apply self-attention LoRA patches if configured."""
|
||||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||||
|
# Only patch if conditions are met
|
||||||
|
can_patch = (
|
||||||
|
self.cfg.lora_dropout == 0
|
||||||
|
if hasattr(self.cfg, "lora_dropout")
|
||||||
|
else True
|
||||||
|
) # default to True if lora_dropout is not set
|
||||||
|
|
||||||
|
if not can_patch:
|
||||||
|
LOG.warning("Cannot patch self-attention - requires no dropout")
|
||||||
|
return
|
||||||
|
|
||||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||||
|
|
||||||
patch_self_attn_lora(self.cfg)
|
patch_self_attn_lora(self.cfg)
|
||||||
|
|||||||
@@ -274,6 +274,29 @@ def find_mlp_in_layer(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_layers(model: PeftModelForCausalLM) -> list[nn.Module]:
|
||||||
|
"""
|
||||||
|
Get the layers of the model. Handles text-only and multimodal models.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A PEFT model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of layers.
|
||||||
|
"""
|
||||||
|
pretrained_model = model.model
|
||||||
|
|
||||||
|
# check for multimodal models first
|
||||||
|
if hasattr(pretrained_model, "language_model"):
|
||||||
|
return pretrained_model.language_model.layers
|
||||||
|
if hasattr(pretrained_model, "model"):
|
||||||
|
return pretrained_model.model.layers
|
||||||
|
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def apply_lora_kernel_patches(
|
def apply_lora_kernel_patches(
|
||||||
model: PeftModelForCausalLM, cfg: DictDefault
|
model: PeftModelForCausalLM, cfg: DictDefault
|
||||||
) -> PeftModelForCausalLM:
|
) -> PeftModelForCausalLM:
|
||||||
@@ -345,17 +368,7 @@ def apply_lora_kernel_patches(
|
|||||||
if activation not in SUPPORTED_ACTIVATIONS:
|
if activation not in SUPPORTED_ACTIVATIONS:
|
||||||
raise NotImplementedError(f"Activation {activation} is not supported")
|
raise NotImplementedError(f"Activation {activation} is not supported")
|
||||||
|
|
||||||
layers = []
|
layers = get_layers(model)
|
||||||
# check for multimodal models first
|
|
||||||
pretrained_model = model.model
|
|
||||||
if hasattr(pretrained_model, "language_model"):
|
|
||||||
layers = pretrained_model.language_model.layers
|
|
||||||
elif hasattr(pretrained_model, "model"):
|
|
||||||
layers = pretrained_model.model.layers
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Model type {model.config.model_type} is not supported yet. Please create an Issue."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Patch each layer
|
# Patch each layer
|
||||||
for layer in layers:
|
for layer in layers:
|
||||||
|
|||||||
@@ -25,7 +25,9 @@ from axolotl.loaders.model import ModelLoader
|
|||||||
from axolotl.loaders.tokenizer import load_tokenizer
|
from axolotl.loaders.tokenizer import load_tokenizer
|
||||||
from axolotl.monkeypatch.lora_kernels import (
|
from axolotl.monkeypatch.lora_kernels import (
|
||||||
apply_lora_kernel_patches,
|
apply_lora_kernel_patches,
|
||||||
|
find_self_attn_in_layer,
|
||||||
get_attention_cls_from_config,
|
get_attention_cls_from_config,
|
||||||
|
get_layers,
|
||||||
patch_self_attn_lora,
|
patch_self_attn_lora,
|
||||||
)
|
)
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
@@ -501,3 +503,63 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
break
|
break
|
||||||
|
|
||||||
assert found_patched_attn
|
assert found_patched_attn
|
||||||
|
|
||||||
|
|
||||||
|
def test_kernel_training_integration_dropout_non_zero():
|
||||||
|
"""Test model loading with dropout non-zero should not patch."""
|
||||||
|
|
||||||
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
|
# Create minimal config
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
|
"learning_rate": 0.000001,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"micro_batch_size": 1,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"adapter": "lora",
|
||||||
|
"lora_r": 8,
|
||||||
|
"lora_alpha": 16,
|
||||||
|
"lora_dropout": 0.1,
|
||||||
|
"lora_target_linear": True,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get original attention class
|
||||||
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
# Store original state before patching
|
||||||
|
original_forward_method = attention_cls.forward
|
||||||
|
|
||||||
|
# Load model
|
||||||
|
model, tokenizer, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
|
# We call modelloader as that's where the patches are applied
|
||||||
|
# despite the fact that we're not using it to load the model
|
||||||
|
model_loader = ModelLoader(cfg, tokenizer)
|
||||||
|
|
||||||
|
# Apply patch
|
||||||
|
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
|
||||||
|
|
||||||
|
# Verify patch was not applied
|
||||||
|
assert attention_cls.forward == original_forward_method
|
||||||
|
|
||||||
|
# Apply apply_lora_kernel_patches
|
||||||
|
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
|
||||||
|
model
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify patch was not applied
|
||||||
|
layers = get_layers(model)
|
||||||
|
for layer in layers:
|
||||||
|
for self_attn in find_self_attn_in_layer(layer):
|
||||||
|
assert not hasattr(self_attn, "apply_qkv")
|
||||||
|
assert not hasattr(self_attn, "apply_o")
|
||||||
|
|||||||
Reference in New Issue
Block a user