support for qwen3 with lora kernels (#2588)

* support for qwen3 with lora kernels

* fix patch

* typo
This commit is contained in:
Wing Lian
2025-04-29 15:02:49 -04:00
parent f04f7cf5ad
commit c337ca0872
2 changed files with 56 additions and 19 deletions

View File

@@ -23,22 +23,42 @@ from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__) LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """ QKV_PATCHES = [
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip( """.lstrip(
"\n" "\n"
) ),
"""
PATCHED_QKV_CODE = """
query_states, key_states, value_states = self.apply_qkv(hidden_states) query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2) query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2) key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2) value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip( """.lstrip(
"\n" "\n"
) ),
),
(
"""
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
]
ORIGINAL_O_CODE = """ ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output) attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try: try:
# Dynamically import the module and attention class # Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}" module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__( model_cls_prefix = "".join(
module_path, fromlist=[f"{model_type.capitalize()}Attention"] [part.capitalize() for part in model_type.split("_")]
) )
attention_cls = getattr(module, f"{model_type.capitalize()}Attention") module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls return attention_cls
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward) self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found" assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found" assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE) for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE) self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace( self_attn_forward = self_attn_forward.replace(
"def forward(", "def forward(",

View File

@@ -9,6 +9,7 @@ from peft import PeftModelForCausalLM, get_peft_config
from transformers import AutoModelForCausalLM, LlamaForCausalLM from transformers import AutoModelForCausalLM, LlamaForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
from axolotl.kernels.lora import ( from axolotl.kernels.lora import (
apply_lora_mlp_geglu, apply_lora_mlp_geglu,
@@ -66,29 +67,36 @@ def small_llama_model():
return LlamaForCausalLM(LlamaConfig(**config)) return LlamaForCausalLM(LlamaConfig(**config))
def test_attention_patching_integration(): @pytest.mark.parametrize(
"model_name,attention_cls",
[
("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
],
)
def test_attention_patching_integration(model_name, attention_cls):
"""Test attention patching in integration context.""" """Test attention patching in integration context."""
cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"} cfg = {"base_model": model_name}
# Store the original implementation # Store the original implementation
original_forward = getattr(LlamaAttention, "forward") original_forward = getattr(attention_cls, "forward")
# Apply patch # Apply patch
patch_self_attn_lora(cfg) patch_self_attn_lora(cfg)
# Get the new forward method # Get the new forward method
patched_forward = LlamaAttention.forward patched_forward = attention_cls.forward
# Check the forward method was replaced # Check the forward method was replaced
assert original_forward is not patched_forward assert original_forward is not patched_forward
assert patched_forward.__name__ == "axolotl_attn_forward" assert patched_forward.__name__ == "axolotl_attn_forward"
# Check original implementation was stored # Check original implementation was stored
assert hasattr(LlamaAttention, "_original_forward") assert hasattr(attention_cls, "_original_forward")
# Clean up # Clean up
setattr(LlamaAttention, "forward", original_forward) setattr(attention_cls, "forward", original_forward)
delattr(LlamaAttention, "_original_forward") delattr(attention_cls, "_original_forward")
def test_swiglu_mlp_integration(small_llama_model): def test_swiglu_mlp_integration(small_llama_model):