chore: flatten directory structure and register for auto classes to save

NanoCode012
2025-02-05 19:17:57 +07:00
parent 9e1c4de13c
commit 49746b184f
14 changed files with 44 additions and 53 deletions
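
In summary, the diff below removes what appears to be the old attention package __init__, rewrites the relative imports (feature_map, rotary, and the linear attention variants) to sit one level up after the flattening, switches the hard-coded logger names to logging.getLogger(__name__), and registers the LinearLlama classes for transformers auto-class saving.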

View File

@@ -1,21 +0,0 @@
"""
Linear and linear attention + sliding window classes
"""
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
from .linear_window_attention_sw import (
LinearAttentionSlidingWindowCache,
LolcatsSlidingWindowAttention,
)
from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
from .linear_window_attention_tk import (
LinearAttentionTKWindowCache,
LolcatsTKWindowAttention,
)
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
LolcatsWindowAttentionTKGen,
)
# Experimental chunk linear attentions
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

View File

@@ -64,6 +64,12 @@ class LinearLlamaConfig(LlamaConfig):
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs)
# self.auto_map = {
# "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
# "AutoModel": "modeling_linear_llama.LinearLlamaModel",
# "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
# }
# Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"}
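
The commented-out auto_map above is the same mapping that register_for_auto_class (see the last hunk) writes into the saved config, so it no longer needs to be hard-coded here. A minimal usage sketch, assuming a flat import path and an attention_type taken from the factory further down; both are illustrative:

# Sketch only: import path and chosen attention_type are assumptions.
from configuration_linear_llama import LinearLlamaConfig  # hypothetical flat path

config = LinearLlamaConfig(attention_config={"attention_type": "lolcats_llama"})
print(config.attention_config)          # {'attention_type': 'lolcats_llama'}

default_config = LinearLlamaConfig()    # falls back to {"attention_type": "softmax"}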

View File

@@ -15,8 +15,8 @@ try:
except ImportError:
fast_causal_dot_product = None
from ..model.feature_map import init_feature_map, init_learned_kernel
from ..model.rotary import apply_rotary_pos_emb
from .feature_map import init_feature_map, init_learned_kernel
from .rotary import apply_rotary_pos_emb
from .utils import repeat_kv
# -------------------

View File

@@ -23,18 +23,15 @@ try:
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import (
LinearAttentionState,
LolcatsLinearAttention,
causal_dot_product,
)
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
)
LOG = logging.getLogger(__name__)
# ----------------------

View File

@@ -11,9 +11,7 @@ import torch.nn.functional as F
from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
)
LOG = logging.getLogger(__name__)
try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention

View File

@@ -22,9 +22,9 @@ try:
except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb
from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"

View File

@@ -11,7 +11,7 @@
import logging
from functools import partial
from typing import Any
from typing import Any, Optional
from torch import nn
from tqdm import tqdm
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding,
)
from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__)
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention
self.self_attn = LolcatsLinearAttention(
base_attn=self.self_attn, # type: ignore
layer_idx=layer_idx,
**config.attention_config,
self.self_attn = convert_llama_attention(
layer=self, attention_config=config.attention_config
)
@@ -229,7 +227,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
def convert_llama_attention(
layer: nn.Module,
attention_config: dict,
layers: list[nn.Module], # list of layers
layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False,
remove_base_attn: bool = True,
):
@@ -239,7 +237,7 @@ def convert_llama_attention(
return get_attention(**attention_config)(
base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1,
max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention,
remove_base_attn=remove_base_attn,
)
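
For reference, a minimal sketch of how the relaxed signature might be used when converting a whole model, assuming traverse_layers yields LlamaDecoderLayer modules as in this file; the model instance and the chosen attention_type are illustrative:

# Sketch only: swap each layer's attention in place; with layers=None the
# converter now passes max_layer_idx=None instead of requiring the full list.
for layer in traverse_layers(model, verbose=True):
    layer.self_attn = convert_llama_attention(
        layer=layer,
        attention_config={"attention_type": "lolcats_llama"},
        layers=None,
        train_attention=False,
        remove_base_attn=True,
    )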
@@ -254,39 +252,41 @@ def get_attention(attention_type: str, **kwargs):
kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama":
from .attention import LolcatsLinearAttention
from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk":
from .attention import LolcatsTKWindowAttention
from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw":
from .attention import LolcatsSlidingWindowAttention
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear":
from .attention import LolcatsLinearSlidingWindowAttention
from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk":
from .attention import LolcatsTKWindowLongAttention
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw":
from .attention import LolcatsSlidingWindowLongAttention
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LolcatsWindowAttentionTKGen
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs)
@@ -304,28 +304,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type:
from .attention import LinearAttentionTKWindowGenerationCache
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type:
from .attention import LinearAttentionTKWindowCache
from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type:
from .attention import LinearAttentionSlidingWindowCache
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type:
from .attention import LinearAttentionSlidingWindowCache
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LinearAttentionTKWindowGenerationCache
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache()
@@ -333,7 +337,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
return past_key_values
else:
from .attention import LinearAttentionState
from .linear_attention import LinearAttentionState
return LinearAttentionState()
@@ -348,3 +352,10 @@ def register_linear_llama():
AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
print("registered transformers")