Support for Qwen3 with LoRA kernels (#2588)

* Support for Qwen3 with LoRA kernels

* fix patch

* typo
This commit is contained in:
Wing Lian
2025-04-29 15:02:49 -04:00
parent f04f7cf5ad
commit c337ca0872
2 changed files with 56 additions and 19 deletions

View File

@@ -23,22 +23,42 @@ from axolotl.utils.dict import DictDefault
LOG = get_logger(__name__)
ORIGINAL_QKV_CODE = """
QKV_PATCHES = [
(
"""
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
PATCHED_QKV_CODE = """
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = query_states.view(hidden_shape).transpose(1, 2)
key_states = key_states.view(hidden_shape).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
)
"\n"
),
),
(
"""
query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
"""
query_states, key_states, value_states = self.apply_qkv(hidden_states)
query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
value_states = value_states.view(hidden_shape).transpose(1, 2)
""".lstrip(
"\n"
),
),
]
ORIGINAL_O_CODE = """
attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
module = __import__(
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
model_cls_prefix = "".join(
[part.capitalize() for part in model_type.split("_")]
)
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
attention_cls = getattr(module, f"{model_cls_prefix}Attention")
return attention_cls
except (ImportError, AttributeError) as e:
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert any(
qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
), "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
for qkv_orig, qkv_patched in QKV_PATCHES:
if qkv_orig in self_attn_forward:
self_attn_forward = self_attn_forward.replace(
qkv_orig,
qkv_patched,
)
break
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",