diff --git a/src/axolotl/monkeypatch/lora_kernels.py b/src/axolotl/monkeypatch/lora_kernels.py
index 0daf16a29..5fa66f5a4 100644
--- a/src/axolotl/monkeypatch/lora_kernels.py
+++ b/src/axolotl/monkeypatch/lora_kernels.py
@@ -4,13 +4,11 @@ import importlib
 import inspect
 import logging
 import types
-from typing import Type
 
 import torch
 from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
-from transformers import AutoConfig
 from transformers.modeling_utils import PreTrainedModel
 
 from axolotl.kernels.lora import (
@@ -97,45 +95,6 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
     return attn_output
 
 
-def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
-    """
-    Get the appropriate attention class by inspecting the model config.
-    Uses dynamic import to support any model architecture that follows
-    the standard transformers naming convention.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        The appropriate attention class for the model.
-
-    Raises:
-        ValueError: If `base_model` not specified or attention class cannot be imported
-        ImportError: If the model module or attention class doesn't exist
-    """
-    if "base_model" not in cfg:
-        raise ValueError("base_model must be specified in config")
-
-    # Get model config without loading the model
-    model_config = AutoConfig.from_pretrained(cfg["base_model"])
-    model_type = model_config.model_type
-
-    try:
-        # Dynamically import the module and attention class
-        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-        module = __import__(
-            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
-        )
-        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
-
-        return attention_cls
-    except (ImportError, AttributeError) as e:
-        raise ValueError(
-            f"Could not import attention class for model_type: {model_type}. "
-            f"Error: {str(e)}"
-        ) from e
-
-
 # pylint: disable=protected-access
 def patch_self_attn_lora(model: PreTrainedModel):
     """
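
Note on the removed helper: `get_attention_cls_from_config` resolved an attention class purely from a checkpoint's `model_type`, importing `transformers.models.{model_type}.modeling_{model_type}` and looking up `{ModelType}Attention`. For reference, a minimal standalone sketch of that dynamic-import pattern follows; the `attention_cls_for` name and its `model_name` argument are illustrative, not from the codebase, and `importlib.import_module` stands in for the original `__import__` call.

import importlib

from torch import nn
from transformers import AutoConfig


def attention_cls_for(model_name: str) -> type[nn.Module]:
    """Resolve the transformers attention class for a checkpoint name.

    Mirrors the removed helper: reads `model_type` from the model config
    (without loading any weights), then imports the matching modeling
    module from transformers.
    """
    model_type = AutoConfig.from_pretrained(model_name).model_type
    module = importlib.import_module(
        f"transformers.models.{model_type}.modeling_{model_type}"
    )
    # str.capitalize() only covers single-word model types, e.g.
    # "llama" -> "LlamaAttention"; multi-word types such as "gpt_neox"
    # (GPTNeoXAttention) do not follow this scheme and would need an
    # explicit mapping -- one reason the helper was brittle.
    return getattr(module, f"{model_type.capitalize()}Attention")


# Usage sketch (fetches only the config from the Hugging Face Hub):
# attn_cls = attention_cls_for("path/to/llama-checkpoint")  # -> LlamaAttention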