support for qwen3 with lora kernels (#2588)

* support for qwen3 with lora kernels

* fix patch

* typo
Author: Wing Lian
Date:   2025-04-29 15:02:49 -04:00
Parent: f04f7cf5ad
Commit: c337ca0872

2 changed files with 56 additions and 19 deletions


@@ -9,6 +9,7 @@ from peft import PeftModelForCausalLM, get_peft_config
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
 
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
@@ -66,29 +67,36 @@ def small_llama_model():
     return LlamaForCausalLM(LlamaConfig(**config))
 
 
-def test_attention_patching_integration():
+@pytest.mark.parametrize(
+    "model_name,attention_cls",
+    [
+        ("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
+        ("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
+    ],
+)
+def test_attention_patching_integration(model_name, attention_cls):
     """Test attention patching in integration context."""
-    cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+    cfg = {"base_model": model_name}
 
     # Store the original implementation
-    original_forward = getattr(LlamaAttention, "forward")
+    original_forward = getattr(attention_cls, "forward")
 
     # Apply patch
     patch_self_attn_lora(cfg)
 
     # Get the new forward method
-    patched_forward = LlamaAttention.forward
+    patched_forward = attention_cls.forward
 
     # Check the forward method was replaced
     assert original_forward is not patched_forward
     assert patched_forward.__name__ == "axolotl_attn_forward"
 
     # Check original implementation was stored
-    assert hasattr(LlamaAttention, "_original_forward")
+    assert hasattr(attention_cls, "_original_forward")
 
     # Clean up
-    setattr(LlamaAttention, "forward", original_forward)
-    delattr(LlamaAttention, "_original_forward")
+    setattr(attention_cls, "forward", original_forward)
+    delattr(attention_cls, "_original_forward")
 
 
 def test_swiglu_mlp_integration(small_llama_model):
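
For illustration only (not part of this commit), a standalone check of the same patching behavior for the Qwen3 MoE case might look like the sketch below. The import path for patch_self_attn_lora is an assumption, since this diff only shows the test module importing from axolotl.kernels.lora; the model name mirrors the parametrization above.

# Minimal sketch, assuming patch_self_attn_lora is exposed by axolotl's
# monkeypatch utilities (assumed path; not shown in this diff).
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention

from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora  # assumed path

original_forward = Qwen3MoeAttention.forward

# The parametrized test suggests the patch picks the attention class to wrap
# based on cfg["base_model"].
cfg = {"base_model": "Qwen/Qwen3-30B-A3B"}
patch_self_attn_lora(cfg)

# After patching, forward is the kernel-aware implementation and the original
# is stashed on the class for later restoration.
assert Qwen3MoeAttention.forward.__name__ == "axolotl_attn_forward"
assert hasattr(Qwen3MoeAttention, "_original_forward")

# Restore the stock implementation, mirroring the test's cleanup.
Qwen3MoeAttention.forward = original_forward
delattr(Qwen3MoeAttention, "_original_forward")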