move patching to post-model load to improve applicability
This commit is contained in:
@@ -11,6 +11,7 @@ from accelerate.logging import get_logger
|
||||
from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers import AutoConfig
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
from axolotl.kernels.lora import (
|
||||
apply_lora_mlp_geglu,
|
||||
@@ -119,12 +120,6 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
||||
model_type = model_config.model_type
|
||||
|
||||
# Special case for model_type = "qwen2"
|
||||
if model_type == "qwen2":
|
||||
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention
|
||||
|
||||
return Qwen2Attention
|
||||
|
||||
try:
|
||||
# Dynamically import the module and attention class
|
||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
||||
@@ -142,62 +137,86 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
def patch_self_attn_lora(cfg: DictDefault):
|
||||
def patch_self_attn_lora(model: PreTrainedModel):
|
||||
"""
|
||||
Given an `axolotl` config, this method patches the inferred attention class forward
|
||||
pass with optimized LoRA implementations.
|
||||
|
||||
It modifies the attention class to use optimized QKV and output projections. The
|
||||
original implementation is preserved and can be restored if needed.
|
||||
Patches the attention classes in a transformer model with optimized LoRA implementations.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
model: A HuggingFace transformers model.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the required code blocks are not found in the attention
|
||||
implementation.
|
||||
"""
|
||||
attention_cls = get_attention_cls_from_config(cfg)
|
||||
# Find all attention modules in the model
|
||||
attention_modules = [
|
||||
module
|
||||
for module in model.modules()
|
||||
if "attention" in module.__class__.__name__.lower()
|
||||
and hasattr(module, "forward")
|
||||
]
|
||||
|
||||
# Check if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
if not attention_modules:
|
||||
LOG.warning("No attention modules found in model")
|
||||
return
|
||||
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
attention_classes = {type(module) for module in attention_modules}
|
||||
LOG.info(f"Found attention classes: {[cls.__name__ for cls in attention_classes]}")
|
||||
|
||||
assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
|
||||
assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
|
||||
for attention_cls in attention_classes:
|
||||
# Skip if already patched
|
||||
if hasattr(attention_cls, "_original_forward"):
|
||||
LOG.info(f"{attention_cls.__name__} already patched")
|
||||
continue
|
||||
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
# Get and store original forward implementation
|
||||
self_attn_forward = inspect.getsource(attention_cls.forward)
|
||||
attention_cls._original_forward = self_attn_forward
|
||||
|
||||
# Load necessary imports
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
# Remove indentation
|
||||
self_attn_forward, _ = detab_code(self_attn_forward)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
# Verify required code blocks exist
|
||||
assert (
|
||||
ORIGINAL_QKV_CODE in self_attn_forward
|
||||
), f"Original QKV code not found in {attention_cls.__name__}"
|
||||
assert (
|
||||
ORIGINAL_O_CODE in self_attn_forward
|
||||
), f"Original O code not found in {attention_cls.__name__}"
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
# Replace code blocks
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
ORIGINAL_QKV_CODE, PATCHED_QKV_CODE
|
||||
)
|
||||
self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
|
||||
self_attn_forward = self_attn_forward.replace(
|
||||
"def forward(",
|
||||
"def axolotl_attn_forward(",
|
||||
1,
|
||||
)
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
# Import necessary symbols from the attention module
|
||||
module_name = attention_cls.__module__
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(module):
|
||||
if item in self_attn_forward:
|
||||
items_to_import.append(item)
|
||||
|
||||
if items_to_import:
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
f"from {module_name} import ({', '.join(items_to_import)})",
|
||||
globals(),
|
||||
)
|
||||
|
||||
# Execute the new implementation
|
||||
exec(self_attn_forward, globals()) # pylint: disable=exec-used # nosec B102
|
||||
|
||||
LOG.info(f"Patched attention class with LoRA optims: {attention_cls.__name__}")
|
||||
attention_cls.forward = (
|
||||
axolotl_attn_forward # pylint: disable=undefined-variable # noqa: F821
|
||||
)
|
||||
|
||||
|
||||
def apply_lora_kernel_patches(
|
||||
|
||||
@@ -439,11 +439,6 @@ class ModelLoader:
|
||||
|
||||
patch_mistral_cross_entropy()
|
||||
|
||||
if self.cfg.unsloth_lora_qkv or self.cfg.unsloth_lora_o:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.cfg)
|
||||
|
||||
def patch_attention(self) -> None:
|
||||
if hasattr(self.model_config, "model_type"):
|
||||
if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
|
||||
@@ -1181,6 +1176,11 @@ class ModelLoader:
|
||||
if self.cfg.adapter is not None:
|
||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||
|
||||
if self.cfg.lora_qkv_kernel or self.cfg.lora_o_kernel:
|
||||
from axolotl.monkeypatch.lora_kernels import patch_self_attn_lora
|
||||
|
||||
patch_self_attn_lora(self.model)
|
||||
|
||||
self.apply_unsloth_lora_patch()
|
||||
self.apply_lora_patch()
|
||||
|
||||
@@ -1201,9 +1201,7 @@ def load_model(
|
||||
reference_model: bool = False,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
|
||||
"""
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
"""Load a model for a given configuration and tokenizer."""
|
||||
loader = ModelLoader(
|
||||
cfg,
|
||||
tokenizer,
|
||||
|
||||
Reference in New Issue
Block a user