removing unused function
This commit is contained in:
@@ -4,13 +4,11 @@ import importlib
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import types
|
import types
|
||||||
from typing import Type
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from peft import PeftModelForCausalLM
|
from peft import PeftModelForCausalLM
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoConfig
|
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
from transformers.modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from axolotl.kernels.lora import (
|
from axolotl.kernels.lora import (
|
||||||
@@ -97,45 +95,6 @@ def original_apply_o(self: nn.Module, hidden_states: torch.Tensor) -> torch.Tens
|
|||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
|
|
||||||
def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
|
|
||||||
"""
|
|
||||||
Get the appropriate attention class by inspecting the model config.
|
|
||||||
Uses dynamic import to support any model architecture that follows
|
|
||||||
the standard transformers naming convention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The appropriate attention class for the model.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If `base_model` not specified or attention class cannot be imported
|
|
||||||
ImportError: If the model module or attention class doesn't exist
|
|
||||||
"""
|
|
||||||
if "base_model" not in cfg:
|
|
||||||
raise ValueError("base_model must be specified in config")
|
|
||||||
|
|
||||||
# Get model config without loading the model
|
|
||||||
model_config = AutoConfig.from_pretrained(cfg["base_model"])
|
|
||||||
model_type = model_config.model_type
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Dynamically import the module and attention class
|
|
||||||
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
|
|
||||||
module = __import__(
|
|
||||||
module_path, fromlist=[f"{model_type.capitalize()}Attention"]
|
|
||||||
)
|
|
||||||
attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
|
|
||||||
|
|
||||||
return attention_cls
|
|
||||||
except (ImportError, AttributeError) as e:
|
|
||||||
raise ValueError(
|
|
||||||
f"Could not import attention class for model_type: {model_type}. "
|
|
||||||
f"Error: {str(e)}"
|
|
||||||
) from e
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def patch_self_attn_lora(model: PreTrainedModel):
|
def patch_self_attn_lora(model: PreTrainedModel):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user