chore: flatten directory structure and register to autoclass to save

This commit is contained in:
NanoCode012
2025-02-05 19:17:57 +07:00
parent 9e1c4de13c
commit 49746b184f
14 changed files with 44 additions and 53 deletions

View File

@@ -1,21 +0,0 @@
"""
Linear and linear attention + sliding window classes
"""
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
from .linear_window_attention_sw import (
LinearAttentionSlidingWindowCache,
LolcatsSlidingWindowAttention,
)
from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
from .linear_window_attention_tk import (
LinearAttentionTKWindowCache,
LolcatsTKWindowAttention,
)
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
LolcatsWindowAttentionTKGen,
)
# Experimental chunk linear attentions
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

View File

@@ -64,6 +64,12 @@ class LinearLlamaConfig(LlamaConfig):
def __init__(self, attention_config: Optional[dict] = None, **kwargs): def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
# self.auto_map = {
# "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
# "AutoModel": "modeling_linear_llama.LinearLlamaModel",
# "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
# }
# Set default attention config if none provided # Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"} self.attention_config = attention_config or {"attention_type": "softmax"}

View File

@@ -15,8 +15,8 @@ try:
except ImportError: except ImportError:
fast_causal_dot_product = None fast_causal_dot_product = None
from ..model.feature_map import init_feature_map, init_learned_kernel from .feature_map import init_feature_map, init_learned_kernel
from ..model.rotary import apply_rotary_pos_emb from .rotary import apply_rotary_pos_emb
from .utils import repeat_kv from .utils import repeat_kv
# ------------------- # -------------------

View File

@@ -23,18 +23,15 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36 _flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers # Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import ( from .linear_attention import (
LinearAttentionState, LinearAttentionState,
LolcatsLinearAttention, LolcatsLinearAttention,
causal_dot_product, causal_dot_product,
) )
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger( LOG = logging.getLogger(__name__)
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
)
# ---------------------- # ----------------------

View File

@@ -11,9 +11,7 @@ import torch.nn.functional as F
from .linear_attention import LinearAttentionState from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger( LOG = logging.getLogger(__name__)
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
)
try: try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention from thunderkittens import hedgehog as tk_window_hedgehog_attention

View File

@@ -22,9 +22,9 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36 _flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb
from .linear_attention import softmax_attention from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention from .linear_window_attention_tk import LolcatsTKWindowAttention
from .rotary import apply_rotary_pos_emb
LOG = logging.getLogger( LOG = logging.getLogger(
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long" "axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"

View File

@@ -11,7 +11,7 @@
import logging import logging
from functools import partial from functools import partial
from typing import Any from typing import Any, Optional
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding, LlamaRotaryEmbedding,
) )
from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LinearLlamaConfig, layer_idx: int): def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx) super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention # Replace the attention layer with our custom attention
self.self_attn = LolcatsLinearAttention( self.self_attn = convert_llama_attention(
base_attn=self.self_attn, # type: ignore layer=self, attention_config=config.attention_config
layer_idx=layer_idx,
**config.attention_config,
) )
@@ -229,7 +227,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
def convert_llama_attention( def convert_llama_attention(
layer: nn.Module, layer: nn.Module,
attention_config: dict, attention_config: dict,
layers: list[nn.Module], # list of layers layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False, train_attention: bool = False,
remove_base_attn: bool = True, remove_base_attn: bool = True,
): ):
@@ -239,7 +237,7 @@ def convert_llama_attention(
return get_attention(**attention_config)( return get_attention(**attention_config)(
base_attn=layer.self_attn, base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36 layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1, max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention, train_attention=train_attention,
remove_base_attn=remove_base_attn, remove_base_attn=remove_base_attn,
) )
@@ -254,39 +252,41 @@ def get_attention(attention_type: str, **kwargs):
kwargs["attention_type"] = attention_type kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama": if attention_type == "lolcats_llama":
from .attention import LolcatsLinearAttention from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs) return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk": elif attention_type == "lolcats_llama_window_tk":
from .attention import LolcatsTKWindowAttention from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs) return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw": elif attention_type == "lolcats_llama_window_sw":
from .attention import LolcatsSlidingWindowAttention from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs) return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear": elif attention_type == "lolcats_llama_window_sw_linear":
from .attention import LolcatsLinearSlidingWindowAttention from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs) return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below # Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk": elif attention_type == "lolcats_long_llama_window_tk":
from .attention import LolcatsTKWindowLongAttention from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs) return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw": elif attention_type == "lolcats_long_llama_window_sw":
from .attention import LolcatsSlidingWindowLongAttention from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs) return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens) # TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen": elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LolcatsWindowAttentionTKGen from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs) return partial(LolcatsWindowAttentionTKGen, **kwargs)
@@ -304,28 +304,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}') # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type: elif "lolcats_llama_window_tk_gen" in attention_type:
from .attention import LinearAttentionTKWindowGenerationCache from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache() return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type: elif "llama_window_tk" in attention_type:
from .attention import LinearAttentionTKWindowCache from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache() return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type: elif "llama_window_sw" in attention_type:
from .attention import LinearAttentionSlidingWindowCache from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache() return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type: elif "llama_window_sw_linear" in attention_type:
from .attention import LinearAttentionSlidingWindowCache from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache() return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens) # TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen": elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LinearAttentionTKWindowGenerationCache from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache() return LinearAttentionTKWindowGenerationCache()
@@ -333,7 +337,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
return past_key_values return past_key_values
else: else:
from .attention import LinearAttentionState from .linear_attention import LinearAttentionState
return LinearAttentionState() return LinearAttentionState()
@@ -348,3 +352,10 @@ def register_linear_llama():
AutoConfig.register("linear_llama", LinearLlamaConfig) AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel) AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM) AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
print("registered transformers")