chore: flatten directory structure and register to autoclass to save
This commit is contained in:
@@ -1,21 +0,0 @@
|
|||||||
"""
|
|
||||||
Linear and linear attention + sliding window classes
|
|
||||||
"""
|
|
||||||
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
|
|
||||||
from .linear_window_attention_sw import (
|
|
||||||
LinearAttentionSlidingWindowCache,
|
|
||||||
LolcatsSlidingWindowAttention,
|
|
||||||
)
|
|
||||||
from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
|
|
||||||
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
|
||||||
from .linear_window_attention_tk import (
|
|
||||||
LinearAttentionTKWindowCache,
|
|
||||||
LolcatsTKWindowAttention,
|
|
||||||
)
|
|
||||||
from .linear_window_attention_tk_gen import (
|
|
||||||
LinearAttentionTKWindowGenerationCache,
|
|
||||||
LolcatsWindowAttentionTKGen,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Experimental chunk linear attentions
|
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
|
||||||
@@ -64,6 +64,12 @@ class LinearLlamaConfig(LlamaConfig):
|
|||||||
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# self.auto_map = {
|
||||||
|
# "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
|
||||||
|
# "AutoModel": "modeling_linear_llama.LinearLlamaModel",
|
||||||
|
# "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
|
||||||
|
# }
|
||||||
|
|
||||||
# Set default attention config if none provided
|
# Set default attention config if none provided
|
||||||
self.attention_config = attention_config or {"attention_type": "softmax"}
|
self.attention_config = attention_config or {"attention_type": "softmax"}
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
fast_causal_dot_product = None
|
fast_causal_dot_product = None
|
||||||
|
|
||||||
from ..model.feature_map import init_feature_map, init_learned_kernel
|
from .feature_map import init_feature_map, init_learned_kernel
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
from .rotary import apply_rotary_pos_emb
|
||||||
from .utils import repeat_kv
|
from .utils import repeat_kv
|
||||||
|
|
||||||
# -------------------
|
# -------------------
|
||||||
@@ -23,18 +23,15 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
|
||||||
|
|
||||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
from .linear_attention import (
|
from .linear_attention import (
|
||||||
LinearAttentionState,
|
LinearAttentionState,
|
||||||
LolcatsLinearAttention,
|
LolcatsLinearAttention,
|
||||||
causal_dot_product,
|
causal_dot_product,
|
||||||
)
|
)
|
||||||
|
from .rotary import apply_rotary_pos_emb
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
LOG = logging.getLogger(__name__)
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------
|
# ----------------------
|
||||||
@@ -11,9 +11,7 @@ import torch.nn.functional as F
|
|||||||
from .linear_attention import LinearAttentionState
|
from .linear_attention import LinearAttentionState
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
LOG = logging.getLogger(__name__)
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||||
@@ -22,9 +22,9 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
|
||||||
from .linear_attention import softmax_attention
|
from .linear_attention import softmax_attention
|
||||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
from .rotary import apply_rotary_pos_emb
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
LOG = logging.getLogger(
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||||
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
LlamaRotaryEmbedding,
|
LlamaRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .attention import LolcatsLinearAttention
|
|
||||||
from .configuration_linear_llama import LinearLlamaConfig
|
from .configuration_linear_llama import LinearLlamaConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
|
|
||||||
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
||||||
super().__init__(config, layer_idx)
|
super().__init__(config, layer_idx)
|
||||||
|
|
||||||
# Replace the attention layer with our custom attention
|
# Replace the attention layer with our custom attention
|
||||||
self.self_attn = LolcatsLinearAttention(
|
self.self_attn = convert_llama_attention(
|
||||||
base_attn=self.self_attn, # type: ignore
|
layer=self, attention_config=config.attention_config
|
||||||
layer_idx=layer_idx,
|
|
||||||
**config.attention_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -229,7 +227,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
|
|||||||
def convert_llama_attention(
|
def convert_llama_attention(
|
||||||
layer: nn.Module,
|
layer: nn.Module,
|
||||||
attention_config: dict,
|
attention_config: dict,
|
||||||
layers: list[nn.Module], # list of layers
|
layers: Optional[list[nn.Module]] = None, # list of layers
|
||||||
train_attention: bool = False,
|
train_attention: bool = False,
|
||||||
remove_base_attn: bool = True,
|
remove_base_attn: bool = True,
|
||||||
):
|
):
|
||||||
@@ -239,7 +237,7 @@ def convert_llama_attention(
|
|||||||
return get_attention(**attention_config)(
|
return get_attention(**attention_config)(
|
||||||
base_attn=layer.self_attn,
|
base_attn=layer.self_attn,
|
||||||
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
||||||
max_layer_idx=len(layers) - 1,
|
max_layer_idx=len(layers) - 1 if layers else None,
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
@@ -254,39 +252,41 @@ def get_attention(attention_type: str, **kwargs):
|
|||||||
kwargs["attention_type"] = attention_type
|
kwargs["attention_type"] = attention_type
|
||||||
|
|
||||||
if attention_type == "lolcats_llama":
|
if attention_type == "lolcats_llama":
|
||||||
from .attention import LolcatsLinearAttention
|
from .linear_attention import LolcatsLinearAttention
|
||||||
|
|
||||||
return partial(LolcatsLinearAttention, **kwargs)
|
return partial(LolcatsLinearAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_tk":
|
elif attention_type == "lolcats_llama_window_tk":
|
||||||
from .attention import LolcatsTKWindowAttention
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
|
||||||
return partial(LolcatsTKWindowAttention, **kwargs)
|
return partial(LolcatsTKWindowAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_sw":
|
elif attention_type == "lolcats_llama_window_sw":
|
||||||
from .attention import LolcatsSlidingWindowAttention
|
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
|
||||||
|
|
||||||
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_sw_linear":
|
elif attention_type == "lolcats_llama_window_sw_linear":
|
||||||
from .attention import LolcatsLinearSlidingWindowAttention
|
from .linear_window_attention_sw_linear import (
|
||||||
|
LolcatsLinearSlidingWindowAttention,
|
||||||
|
)
|
||||||
|
|
||||||
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
# Experimental chunked linear attentions below
|
# Experimental chunked linear attentions below
|
||||||
elif attention_type == "lolcats_long_llama_window_tk":
|
elif attention_type == "lolcats_long_llama_window_tk":
|
||||||
from .attention import LolcatsTKWindowLongAttention
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_long_llama_window_sw":
|
elif attention_type == "lolcats_long_llama_window_sw":
|
||||||
from .attention import LolcatsSlidingWindowLongAttention
|
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
||||||
|
|
||||||
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
# TK generation build (requires Thunderkittens)
|
# TK generation build (requires Thunderkittens)
|
||||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
from .attention import LolcatsWindowAttentionTKGen
|
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
|
||||||
|
|
||||||
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
||||||
|
|
||||||
@@ -304,28 +304,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
|||||||
|
|
||||||
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
||||||
elif "lolcats_llama_window_tk_gen" in attention_type:
|
elif "lolcats_llama_window_tk_gen" in attention_type:
|
||||||
from .attention import LinearAttentionTKWindowGenerationCache
|
from .linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
)
|
||||||
|
|
||||||
return LinearAttentionTKWindowGenerationCache()
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
elif "llama_window_tk" in attention_type:
|
elif "llama_window_tk" in attention_type:
|
||||||
from .attention import LinearAttentionTKWindowCache
|
from .linear_window_attention_tk import LinearAttentionTKWindowCache
|
||||||
|
|
||||||
return LinearAttentionTKWindowCache()
|
return LinearAttentionTKWindowCache()
|
||||||
|
|
||||||
elif "llama_window_sw" in attention_type:
|
elif "llama_window_sw" in attention_type:
|
||||||
from .attention import LinearAttentionSlidingWindowCache
|
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
return LinearAttentionSlidingWindowCache()
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
elif "llama_window_sw_linear" in attention_type:
|
elif "llama_window_sw_linear" in attention_type:
|
||||||
from .attention import LinearAttentionSlidingWindowCache
|
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
return LinearAttentionSlidingWindowCache()
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
# TK generation build (requires Thunderkittens)
|
# TK generation build (requires Thunderkittens)
|
||||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
from .attention import LinearAttentionTKWindowGenerationCache
|
from .linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
)
|
||||||
|
|
||||||
return LinearAttentionTKWindowGenerationCache()
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
@@ -333,7 +337,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
|||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .attention import LinearAttentionState
|
from .linear_attention import LinearAttentionState
|
||||||
|
|
||||||
return LinearAttentionState()
|
return LinearAttentionState()
|
||||||
|
|
||||||
@@ -348,3 +352,10 @@ def register_linear_llama():
|
|||||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||||
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
||||||
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||||
|
|
||||||
|
# registering for auto classes to save files
|
||||||
|
LinearLlamaConfig.register_for_auto_class("AutoConfig")
|
||||||
|
LinearLlamaModel.register_for_auto_class("AutoModel")
|
||||||
|
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||||
|
|
||||||
|
print("registered transformers")
|
||||||
|
|||||||
Reference in New Issue
Block a user