removing unused function
@@ -4,13 +4,11 @@ import importlib
 import inspect
 import logging
 import types
-from typing import Type
 
 import torch
 from accelerate.logging import get_logger
 from peft import PeftModelForCausalLM
 from torch import nn
-from transformers import AutoConfig
 from transformers.modeling_utils import PreTrainedModel
 
 from axolotl.kernels.lora import (
@@ -97,45 +95,6 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tensor:
     return attn_output
 
 
-def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
-    """
-    Get the appropriate attention class by inspecting the model config.
-    Uses dynamic import to support any model architecture that follows
-    the standard transformers naming convention.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-
-    Returns:
-        The appropriate attention class for the model.
-
-    Raises:
-        ValueError: If `base_model` not specified or attention class cannot be imported
-        ImportError: If the model module or attention class doesn't exist
-    """
-    if "base_model" not in cfg:
-        raise ValueError("base_model must be specified in config")
-
-    # Get model config without loading the model
-    model_config = AutoConfig.from_pretrained(cfg["base_model"])
-    model_type = model_config.model_type
-
-    try:
-        # Dynamically import the module and attention class
-        module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-        module = __import__(
-            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
-        )
-        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
-
-        return attention_cls
-    except (ImportError, AttributeError) as e:
-        raise ValueError(
-            f"Could not import attention class for model_type: {model_type}. "
-            f"Error: {str(e)}"
-        ) from e
-
-
 # pylint: disable=protected-access
 def patch_self_attn_lora(model: PreTrainedModel):
     """
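For context, the removed helper resolved attention classes purely by the transformers naming convention (transformers.models.<model_type>.modeling_<model_type>.<ModelType>Attention). Below is a minimal, standalone sketch of that lookup pattern, using importlib.import_module in place of raw __import__; the helper name lookup_attention_cls is hypothetical. Note the limitation inherited from str.capitalize(): it resolves single-word model types such as "llama" (LlamaAttention) but not compound ones such as "gpt_neox", whose class is GPTNeoXAttention rather than Gpt_neoxAttention.

    import importlib
    from torch import nn

    def lookup_attention_cls(model_type: str) -> type[nn.Module]:
        """Resolve an attention class by transformers naming convention.

        Sketch only: assumes the module
        transformers.models.<model_type>.modeling_<model_type> exists and
        that capitalizing the model type yields the class name prefix.
        """
        module = importlib.import_module(
            f"transformers.models.{model_type}.modeling_{model_type}"
        )
        return getattr(module, f"{model_type.capitalize()}Attention")

    # Example: resolves transformers.models.llama.modeling_llama.LlamaAttention
    attention_cls = lookup_attention_cls("llama")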