Lora kernels fix (#2732)
* fix lora kernel patching and improve test * simplification
This commit is contained in:
@@ -21,8 +21,11 @@ from axolotl.kernels.lora import (
|
||||
apply_lora_o,
|
||||
apply_lora_qkv,
|
||||
)
|
||||
from axolotl.loaders.model import ModelLoader
|
||||
from axolotl.loaders.tokenizer import load_tokenizer
|
||||
from axolotl.monkeypatch.lora_kernels import (
|
||||
apply_lora_kernel_patches,
|
||||
get_attention_cls_from_config,
|
||||
patch_self_attn_lora,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
@@ -80,7 +83,7 @@ def small_llama_model():
|
||||
)
|
||||
def test_attention_patching_integration(model_name, attention_cls):
|
||||
"""Test attention patching in integration context."""
|
||||
cfg = {"base_model": model_name}
|
||||
cfg = DictDefault({"base_model": model_name})
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(attention_cls, "forward")
|
||||
@@ -466,3 +469,35 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
||||
assert cfg.lora_mlp_kernel is True
|
||||
assert cfg.lora_qkv_kernel is True
|
||||
assert cfg.lora_o_kernel is True
|
||||
|
||||
# Get the attention class before patching to check for side effects
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
|
||||
# Store original state before patching
|
||||
original_forward_method = attention_cls.forward
|
||||
|
||||
# Load the model (this should trigger the patches)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
model, _ = ModelLoader(cfg, tokenizer).load()
|
||||
|
||||
# Test side effects of patch_self_attn_lora
|
||||
assert hasattr(attention_cls, "_original_forward")
|
||||
assert attention_cls.forward != original_forward_method
|
||||
|
||||
# Find at least one self-attention module and verify it has the patched methods
|
||||
found_patched_attn = False
|
||||
for layer in model.model.model.layers:
|
||||
if hasattr(layer, "self_attn"):
|
||||
self_attn = layer.self_attn
|
||||
if all(
|
||||
hasattr(self_attn, proj)
|
||||
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||
):
|
||||
# These methods should be added by apply_lora_kernel_patches
|
||||
assert hasattr(self_attn, "apply_qkv") and callable(self_attn.apply_qkv)
|
||||
assert hasattr(self_attn, "apply_o") and callable(self_attn.apply_o)
|
||||
|
||||
found_patched_attn = True
|
||||
break
|
||||
|
||||
assert found_patched_attn
|
||||
|
||||
Reference in New Issue
Block a user