Move LoRA self-attention patching to after model load so it applies to the attention classes actually present in the loaded model

This commit is contained in:
Dan Saunders
2025-02-18 19:00:12 +00:00
parent c3d4f6e295
commit 945dcc5020
3 changed files with 108 additions and 92 deletions

View File

@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
from peft import PeftModelForCausalLM
from torch import nn
from transformers import AutoConfig
from transformers.modeling_utils import PreTrainedModel
from axolotl.kernels.lora import (
apply_lora_mlp_geglu,
@@ -119,12 +120,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
model_config = AutoConfig.from_pretrained(cfg["base_model"])
model_type = model_config.model_type
# Special case for model_type = "qwen2"
if model_type == "qwen2":
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
return Qwen2Attention
try:
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
@@ -142,62 +137,86 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
# pylint: disable=protected-access
def patch_self_attn_lora(cfg: DictDefault):
def patch_self_attn_lora(model: PreTrainedModel):
"""
Given an `axolotl` config, this method patches the inferred attention class forward
pass with optimized LoRA implementations.
It modifies the attention class to use optimized QKV and output projections. The
original implementation is preserved and can be restored if needed.
Patches the attention classes in a transformer model with optimized LoRA implementations.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
model: A HuggingFace transformers model.
Raises:
AssertionError: If the required code blocks are not found in the attention
implementation.
"""
attention_cls = get_attention_cls_from_config(cfg)
# Find all attention modules in the model
attention_modules = [
module
for module in model.modules()
if "attention" in module.__class__.__name__.lower()
and hasattr(module, "forward")
]
# Check if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
if not attention_modules:
LOG.warning("No attention modules found in model")
return
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
self_attn_forward, _ = detab_code(self_attn_forward)
attention_classes = {type(module) for module in attention_modules}
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
for attention_cls in attention_classes:
# Skip if already patched
if hasattr(attention_cls, "_original_forward"):
LOG.info(f"{attention_cls.__name__} already patched")
continue
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
# Get and store original forward implementation
self_attn_forward = inspect.getsource(attention_cls.forward)
attention_cls._original_forward = self_attn_forward
# Load necessary imports
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
# Remove indentation
self_attn_forward, _ = detab_code(self_attn_forward)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
# Verify required code blocks exist
assert (
ORIGINAL_QKV_CODE in self_attn_forward
), f"Original QKV code not found in {attention_cls.__name__}"
assert (
ORIGINAL_O_CODE in self_attn_forward
), f"Original O code not found in {attention_cls.__name__}"
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
# Replace code blocks
self_attn_forward = self_attn_forward.replace(
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
)
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
self_attn_forward = self_attn_forward.replace(
"def forward(",
"def axolotl_attn_forward(",
1,
)
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
# Import necessary symbols from the attention module
module_name = attention_cls.__module__
module = importlib.import_module(module_name)
items_to_import = []
for item in dir(module):
if item in self_attn_forward:
items_to_import.append(item)
if items_to_import:
exec( # pylint: disable=exec-used # nosec B102
f"from {module_name} import ({', '.join(items_to_import)})",
globals(),
)
# Execute the new implementation
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
attention_cls.forward = (
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
)
def apply_lora_kernel_patches(

View File

@@ -439,11 +439,6 @@ class ModelLoader:
patch_mistral_cross_entropy()
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.cfg)
def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -1181,6 +1176,11 @@ class ModelLoader:
if self.cfg.adapter is not None:
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
patch_self_attn_lora(self.model)
self.apply_unsloth_lora_patch()
self.apply_lora_patch()
@@ -1201,9 +1201,7 @@ def load_model(
reference_model: bool = False,
**kwargs, # pylint: disable=unused-argument
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
"""
Load a model for a given configuration and tokenizer.
"""
"""Load a model for a given configuration and tokenizer."""
loader = ModelLoader(
cfg,
tokenizer,