support for qwen3 with lora kernels (#2588)
* support for qwen3 with lora kernels
* fix patch
* typo
This commit is contained in:
@@ -9,6 +9,7 @@ from peft import PeftModelForCausalLM, get_peft_config
|
||||
from transformers import AutoModelForCausalLM, LlamaForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
@@ -66,29 +67,36 @@ def small_llama_model():
|
||||
return LlamaForCausalLM(LlamaConfig(**config))
|
||||
|
||||
|
||||
def test_attention_patching_integration():
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,attention_cls",
|
||||
[
|
||||
("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
|
||||
("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
|
||||
],
|
||||
)
|
||||
def test_attention_patching_integration(model_name, attention_cls):
|
||||
"""Test attention patching in integration context."""
|
||||
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
|
||||
cfg = {"base_model": model_name}
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(LlamaAttention, "forward")
|
||||
original_forward = getattr(attention_cls, "forward")
|
||||
|
||||
# Apply patch
|
||||
patch_self_attn_lora(cfg)
|
||||
|
||||
# Get the new forward method
|
||||
patched_forward = LlamaAttention.forward
|
||||
patched_forward = attention_cls.forward
|
||||
|
||||
# Check the forward method was replaced
|
||||
assert original_forward is not patched_forward
|
||||
assert patched_forward.__name__ == "axolotl_attn_forward"
|
||||
|
||||
# Check original implementation was stored
|
||||
assert hasattr(LlamaAttention, "_original_forward")
|
||||
assert hasattr(attention_cls, "_original_forward")
|
||||
|
||||
# Clean up
|
||||
setattr(LlamaAttention, "forward", original_forward)
|
||||
delattr(LlamaAttention, "_original_forward")
|
||||
setattr(attention_cls, "forward", original_forward)
|
||||
delattr(attention_cls, "_original_forward")
|
||||
|
||||
|
||||
def test_swiglu_mlp_integration(small_llama_model):
|
||||
|
||||
Reference in New Issue
Block a user