chore: flatten directory structure and register auto classes for saving
@@ -1,21 +0,0 @@
-"""
-Linear and linear attention + sliding window classes
-"""
-from .linear_attention import LinearAttentionState, LolcatsLinearAttention
-from .linear_window_attention_sw import (
-    LinearAttentionSlidingWindowCache,
-    LolcatsSlidingWindowAttention,
-)
-from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
-from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
-from .linear_window_attention_tk import (
-    LinearAttentionTKWindowCache,
-    LolcatsTKWindowAttention,
-)
-from .linear_window_attention_tk_gen import (
-    LinearAttentionTKWindowGenerationCache,
-    LolcatsWindowAttentionTKGen,
-)
-
-# Experimental chunk linear attentions
-from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
@@ -64,6 +64,12 @@ class LinearLlamaConfig(LlamaConfig):
     def __init__(self, attention_config: Optional[dict] = None, **kwargs):
         super().__init__(**kwargs)

+        # self.auto_map = {
+        #     "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
+        #     "AutoModel": "modeling_linear_llama.LinearLlamaModel",
+        #     "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
+        # }
+
         # Set default attention config if none provided
         self.attention_config = attention_config or {"attention_type": "softmax"}

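
Not part of the diff, for orientation only: a minimal usage sketch of the config change above. The module path is an assumption inferred from the later `from .configuration_linear_llama import LinearLlamaConfig` hunk and the logger names in this diff; the hidden_size value is an arbitrary illustration.

# Hypothetical usage sketch (module path and values assumed, see note above).
from axolotl.integrations.lolcats.configuration_linear_llama import LinearLlamaConfig

cfg = LinearLlamaConfig(
    attention_config={"attention_type": "lolcats_llama"},
    hidden_size=1024,  # any standard LlamaConfig kwarg passes through **kwargs
)
print(cfg.attention_config)                   # {'attention_type': 'lolcats_llama'}
print(LinearLlamaConfig().attention_config)   # default: {'attention_type': 'softmax'}
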
@@ -15,8 +15,8 @@ try:
 except ImportError:
     fast_causal_dot_product = None

-from ..model.feature_map import init_feature_map, init_learned_kernel
-from ..model.rotary import apply_rotary_pos_emb
+from .feature_map import init_feature_map, init_learned_kernel
+from .rotary import apply_rotary_pos_emb
 from .utils import repeat_kv

 # -------------------
@@ -23,18 +23,15 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36

-from ..model.rotary import apply_rotary_pos_emb
-
 # Causal linear attention dot product CUDA kernel from fast-transformers
 from .linear_attention import (
     LinearAttentionState,
     LolcatsLinearAttention,
     causal_dot_product,
 )
+from .rotary import apply_rotary_pos_emb

-LOG = logging.getLogger(
-    "axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
-)
+LOG = logging.getLogger(__name__)


 # ----------------------
@@ -11,9 +11,7 @@ import torch.nn.functional as F
 from .linear_attention import LinearAttentionState
 from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

-LOG = logging.getLogger(
-    "axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
-)
+LOG = logging.getLogger(__name__)

 try:
     from thunderkittens import hedgehog as tk_window_hedgehog_attention
@@ -22,9 +22,9 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36

-from ..model.rotary import apply_rotary_pos_emb
 from .linear_attention import softmax_attention
 from .linear_window_attention_tk import LolcatsTKWindowAttention
+from .rotary import apply_rotary_pos_emb

 LOG = logging.getLogger(
     "axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
@@ -11,7 +11,7 @@

 import logging
 from functools import partial
-from typing import Any
+from typing import Any, Optional

 from torch import nn
 from tqdm import tqdm
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaRotaryEmbedding,
 )

-from .attention import LolcatsLinearAttention
 from .configuration_linear_llama import LinearLlamaConfig

 LOG = logging.getLogger(__name__)
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):

     def __init__(self, config: LinearLlamaConfig, layer_idx: int):
         super().__init__(config, layer_idx)

         # Replace the attention layer with our custom attention
-        self.self_attn = LolcatsLinearAttention(
-            base_attn=self.self_attn,  # type: ignore
-            layer_idx=layer_idx,
-            **config.attention_config,
+        self.self_attn = convert_llama_attention(
+            layer=self, attention_config=config.attention_config
         )

@@ -229,7 +227,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
 def convert_llama_attention(
     layer: nn.Module,
     attention_config: dict,
-    layers: list[nn.Module],  # list of layers
+    layers: Optional[list[nn.Module]] = None,  # list of layers
     train_attention: bool = False,
     remove_base_attn: bool = True,
 ):
@@ -239,7 +237,7 @@ def convert_llama_attention(
     return get_attention(**attention_config)(
         base_attn=layer.self_attn,
         layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
-        max_layer_idx=len(layers) - 1,
+        max_layer_idx=len(layers) - 1 if layers else None,
         train_attention=train_attention,
         remove_base_attn=remove_base_attn,
     )
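
Not part of the diff: a hedged sketch of driving the converter above over a model's decoder layers. It assumes traverse_layers(model) yields the decoder layers (its signature appears in the hunk header above, but its body is not shown here) and uses an attention_type handled by get_attention below; `model` is a placeholder.

# Hypothetical driver (assumptions noted above).
layers = list(traverse_layers(model))
for layer in layers:
    layer.self_attn = convert_llama_attention(
        layer=layer,
        attention_config={"attention_type": "lolcats_llama"},
        layers=layers,  # now optional; omitting it makes max_layer_idx None
    )
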
@@ -254,39 +252,41 @@ def get_attention(attention_type: str, **kwargs):
     kwargs["attention_type"] = attention_type

     if attention_type == "lolcats_llama":
-        from .attention import LolcatsLinearAttention
+        from .linear_attention import LolcatsLinearAttention

         return partial(LolcatsLinearAttention, **kwargs)

     elif attention_type == "lolcats_llama_window_tk":
-        from .attention import LolcatsTKWindowAttention
+        from .linear_window_attention_tk import LolcatsTKWindowAttention

         return partial(LolcatsTKWindowAttention, **kwargs)

     elif attention_type == "lolcats_llama_window_sw":
-        from .attention import LolcatsSlidingWindowAttention
+        from .linear_window_attention_sw import LolcatsSlidingWindowAttention

         return partial(LolcatsSlidingWindowAttention, **kwargs)

     elif attention_type == "lolcats_llama_window_sw_linear":
-        from .attention import LolcatsLinearSlidingWindowAttention
+        from .linear_window_attention_sw_linear import (
+            LolcatsLinearSlidingWindowAttention,
+        )

         return partial(LolcatsLinearSlidingWindowAttention, **kwargs)

     # Experimental chunked linear attentions below
     elif attention_type == "lolcats_long_llama_window_tk":
-        from .attention import LolcatsTKWindowLongAttention
+        from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

         return partial(LolcatsTKWindowLongAttention, **kwargs)

     elif attention_type == "lolcats_long_llama_window_sw":
-        from .attention import LolcatsSlidingWindowLongAttention
+        from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention

         return partial(LolcatsSlidingWindowLongAttention, **kwargs)

     # TK generation build (requires Thunderkittens)
     elif attention_type == "lolcats_llama_window_tk_gen":
-        from .attention import LolcatsWindowAttentionTKGen
+        from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen

         return partial(LolcatsWindowAttentionTKGen, **kwargs)

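
Not part of the diff: an illustration of the partial-based dispatch above. The keyword names mirror what convert_llama_attention passes earlier in this file; old_attn and the index values are placeholders.

# Hypothetical usage (assumptions noted above).
attn_ctor = get_attention(attention_type="lolcats_llama_window_sw")
new_attn = attn_ctor(
    base_attn=old_attn,        # the original LlamaAttention module being replaced
    layer_idx=0,
    max_layer_idx=31,
    train_attention=False,
    remove_base_attn=True,
)
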
@@ -304,28 +304,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):

     # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
     elif "lolcats_llama_window_tk_gen" in attention_type:
-        from .attention import LinearAttentionTKWindowGenerationCache
+        from .linear_window_attention_tk_gen import (
+            LinearAttentionTKWindowGenerationCache,
+        )

         return LinearAttentionTKWindowGenerationCache()

     elif "llama_window_tk" in attention_type:
-        from .attention import LinearAttentionTKWindowCache
+        from .linear_window_attention_tk import LinearAttentionTKWindowCache

         return LinearAttentionTKWindowCache()

     elif "llama_window_sw" in attention_type:
-        from .attention import LinearAttentionSlidingWindowCache
+        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache

         return LinearAttentionSlidingWindowCache()

     elif "llama_window_sw_linear" in attention_type:
-        from .attention import LinearAttentionSlidingWindowCache
+        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache

         return LinearAttentionSlidingWindowCache()

     # TK generation build (requires Thunderkittens)
     elif attention_type == "lolcats_llama_window_tk_gen":
-        from .attention import LinearAttentionTKWindowGenerationCache
+        from .linear_window_attention_tk_gen import (
+            LinearAttentionTKWindowGenerationCache,
+        )

         return LinearAttentionTKWindowGenerationCache()

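
Not part of the diff: a brief sketch of the cache selector above, assuming it is called once per generation request with the configured attention_type string.

# Hypothetical usage (assumption noted above).
cache = get_attention_cache("lolcats_llama_window_sw")
print(type(cache).__name__)  # LinearAttentionSlidingWindowCache
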
@@ -333,7 +337,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
         return past_key_values

     else:
-        from .attention import LinearAttentionState
+        from .linear_attention import LinearAttentionState

        return LinearAttentionState()

@@ -348,3 +352,10 @@ def register_linear_llama():
     AutoConfig.register("linear_llama", LinearLlamaConfig)
     AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
     AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
+
+    # registering for auto classes to save files
+    LinearLlamaConfig.register_for_auto_class("AutoConfig")
+    LinearLlamaModel.register_for_auto_class("AutoModel")
+    LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
+
+    print("registered transformers")
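
Not part of the diff: a hedged sketch of what the added registrations enable. The Auto* register calls make the "linear_llama" model type resolvable in-process, and register_for_auto_class marks the classes so save_pretrained also writes auto_map entries and the defining code files, letting the checkpoint reload with trust_remote_code=True. The size kwargs and output path below are placeholders.

# Hypothetical round trip (assumptions noted above).
from transformers import AutoConfig, AutoModelForCausalLM

register_linear_llama()

cfg = AutoConfig.for_model("linear_llama", num_hidden_layers=2, hidden_size=256)
model = AutoModelForCausalLM.from_config(cfg)

model.save_pretrained("linear-llama-ckpt")  # writes config, weights, and custom code
reloaded = AutoModelForCausalLM.from_pretrained(
    "linear-llama-ckpt", trust_remote_code=True
)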