From 07e4f2e25b5d63ab84216ffb7c473cbe4b0b9582 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 29 Apr 2025 15:02:49 -0400
Subject: [PATCH] support for qwen3 with lora kernels (#2588)

* support for qwen3 with lora kernels

* fix patch

* typo
---
 src/axolotl/monkeypatch/lora_kernels.py       | 53 ++++++++++++++-----
 .../lora_kernels/test_lora_kernel_patching.py | 22 +++++---
 2 files changed, 56 insertions(+), 19 deletions(-)

diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 0036fe003..6c920dcc8 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -23,22 +23,42 @@ from axolotl.utils.dict import DictDefault
 
 LOG = get_logger(__name__)
 
-ORIGINAL_QKV_CODE = """
+QKV_PATCHES = [
+    (
+        """
     query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 """.lstrip(
-    "\n"
-)
-
-PATCHED_QKV_CODE = """
+            "\n"
+        ),
+        """
     query_states, key_states, value_states = self.apply_qkv(hidden_states)
     query_states = query_states.view(hidden_shape).transpose(1, 2)
     key_states = key_states.view(hidden_shape).transpose(1, 2)
     value_states = value_states.view(hidden_shape).transpose(1, 2)
 """.lstrip(
-    "\n"
-)
+            "\n"
+        ),
+    ),
+    (
+        """
+    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+""".lstrip(
+            "\n"
+        ),
+        """
+    query_states, key_states, value_states = self.apply_qkv(hidden_states)
+    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
+    key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
+    value_states = value_states.view(hidden_shape).transpose(1, 2)
+""".lstrip(
+            "\n"
+        ),
+    ),
+]
 
 ORIGINAL_O_CODE = """
     attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
     try:
         # Dynamically import the module and attention class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-        module = __import__(
-            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
+        model_cls_prefix = "".join(
+            [part.capitalize() for part in model_type.split("_")]
         )
-        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
+        module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
+        attention_cls = getattr(module, f"{model_cls_prefix}Attention")
 
         return attention_cls
     except (ImportError, AttributeError) as e:
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
     attention_cls._original_forward = self_attn_forward
     self_attn_forward, _ = detab_code(self_attn_forward)
 
-    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
+    assert any(
+        qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
+    ), "Original QKV code not found"
     assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
 
-    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
+    for qkv_orig, qkv_patched in QKV_PATCHES:
+        if qkv_orig in self_attn_forward:
+            self_attn_forward = self_attn_forward.replace(
+                qkv_orig,
+                qkv_patched,
+            )
+            break
     self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
     self_attn_forward = self_attn_forward.replace(
         "def forward(",
diff --git a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
index bada305b3..eb0c73225 100644
--- a/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
+++ b/tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py
@@ -9,6 +9,7 @@ from peft import PeftModelForCausalLM, get_peft_config
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
 
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
@@ -66,29 +67,36 @@ def small_llama_model():
     return LlamaForCausalLM(LlamaConfig(**config))
 
 
-def test_attention_patching_integration():
+@pytest.mark.parametrize(
+    "model_name,attention_cls",
+    [
+        ("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
+        ("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
+    ],
+)
+def test_attention_patching_integration(model_name, attention_cls):
     """Test attention patching in integration context."""
-    cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+    cfg = {"base_model": model_name}
 
     # Store the original implementation
-    original_forward = getattr(LlamaAttention, "forward")
+    original_forward = getattr(attention_cls, "forward")
 
     # Apply patch
     patch_self_attn_lora(cfg)
 
     # Get the new forward method
-    patched_forward = LlamaAttention.forward
+    patched_forward = attention_cls.forward
 
     # Check the forward method was replaced
     assert original_forward is not patched_forward
     assert patched_forward.__name__ == "axolotl_attn_forward"
 
     # Check original implementation was stored
-    assert hasattr(LlamaAttention, "_original_forward")
+    assert hasattr(attention_cls, "_original_forward")
 
     # Clean up
-    setattr(LlamaAttention, "forward", original_forward)
-    delattr(LlamaAttention, "_original_forward")
+    setattr(attention_cls, "forward", original_forward)
+    delattr(attention_cls, "_original_forward")
 
 
 def test_swiglu_mlp_integration(small_llama_model):
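
Why QKV_PATCHES is now a list of (original, patched) source pairs: Qwen3 and Qwen3 MoE attention pass the projected query/key states through q_norm/k_norm, so the Llama-style QKV source no longer appears verbatim in their forward methods. patch_self_attn_lora therefore checks the unpatched forward source against every known pattern and rewrites only the first one that matches. The helper below is a minimal, self-contained sketch of that first-match replacement, not axolotl's actual code, and the snippets are shortened stand-ins for the real QKV_PATCHES strings.

def apply_first_matching_patch(source: str, patches: list[tuple[str, str]]) -> str:
    """Rewrite the first (original, patched) pair whose original snippet occurs in source."""
    if not any(original in source for original, _ in patches):
        raise AssertionError("Original QKV code not found")
    for original, patched in patches:
        if original in source:
            return source.replace(original, patched)
    return source  # unreachable: the check above guarantees a match

# Shortened stand-ins for the Llama-style and Qwen3-style QKV snippets.
toy_patches = [
    ("q = self.q_proj(x)", "q, k, v = self.apply_qkv(x)"),
    ("q = self.q_norm(self.q_proj(x))", "q, k, v = self.apply_qkv(x)"),
]
llama_like = "def forward(self, x):\n    q = self.q_proj(x)\n"
qwen3_like = "def forward(self, x):\n    q = self.q_norm(self.q_proj(x))\n"
assert "apply_qkv" in apply_first_matching_patch(llama_like, toy_patches)
assert "apply_qkv" in apply_first_matching_patch(qwen3_like, toy_patches)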
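
The get_attention_cls_from_config change addresses how the attention class is looked up by name: model_type.capitalize() only works for single-word model types. For qwen3_moe it yields "Qwen3_moe", while transformers names the class Qwen3MoeAttention, so the old getattr would fail. Capitalizing each underscore-separated part reproduces the transformers naming scheme. A small sketch of the derivation; attention_cls_name is an illustrative helper, not an axolotl function.

def attention_cls_name(model_type: str) -> str:
    # Same prefix construction as the patched model_cls_prefix above.
    prefix = "".join(part.capitalize() for part in model_type.split("_"))
    return f"{prefix}Attention"

assert attention_cls_name("llama") == "LlamaAttention"
assert attention_cls_name("qwen3") == "Qwen3Attention"
assert attention_cls_name("qwen3_moe") == "Qwen3MoeAttention"
# The old single capitalize() would have produced the wrong name here:
assert f'{"qwen3_moe".capitalize()}Attention' == "Qwen3_moeAttention"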
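
The parametrized test now covers both the Llama path (HuggingFaceTB/SmolLM2-135M) and the Qwen3 MoE path (Qwen/Qwen3-30B-A3B). The same check can be reproduced outside pytest; the sketch below mirrors the test body and assumes network access so the base_model config can be resolved to the qwen3_moe model type.

from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention

from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora

original_forward = Qwen3MoeAttention.forward
patch_self_attn_lora({"base_model": "Qwen/Qwen3-30B-A3B"})

# forward is swapped for the patched, apply_qkv-based version.
assert Qwen3MoeAttention.forward is not original_forward
assert Qwen3MoeAttention.forward.__name__ == "axolotl_attn_forward"

# Undo the patch, as the test teardown does.
Qwen3MoeAttention.forward = original_forward
del Qwen3MoeAttention._original_forward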