From 49746b184f5cae5e84a7243fd2dbbcae2365b6c4 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 5 Feb 2025 19:17:57 +0700
Subject: [PATCH] chore: flatten directory structure and register auto classes
 for saving

---
 .../linear_llama/attention/__init__.py       | 21 --------
 .../configuration_linear_llama.py            |  6 +++
 .../linear_llama/{model => }/feature_map.py  |  0
 .../{attention => }/linear_attention.py      |  4 +-
 .../linear_window_attention_sw.py            |  0
 .../linear_window_attention_sw_linear.py     |  7 +--
 .../linear_window_attention_sw_long.py       |  0
 .../linear_window_attention_tk.py            |  0
 .../linear_window_attention_tk_gen.py        |  4 +-
 .../linear_window_attention_tk_long.py       |  2 +-
 .../lolcats/linear_llama/model/__init__.py   |  0
 .../linear_llama/modeling_linear_llama.py    | 53 +++++++++++--------
 .../linear_llama/{model => }/rotary.py       |  0
 .../linear_llama/{attention => }/utils.py    |  0
 14 files changed, 44 insertions(+), 53 deletions(-)
 delete mode 100644 src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py
 rename src/axolotl/integrations/lolcats/linear_llama/{model => }/feature_map.py (100%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_attention.py (99%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_sw.py (100%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_sw_linear.py (99%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_sw_long.py (100%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_tk.py (100%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_tk_gen.py (98%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/linear_window_attention_tk_long.py (99%)
 delete mode 100644 src/axolotl/integrations/lolcats/linear_llama/model/__init__.py
 rename src/axolotl/integrations/lolcats/linear_llama/{model => }/rotary.py (100%)
 rename src/axolotl/integrations/lolcats/linear_llama/{attention => }/utils.py (100%)

diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py b/src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py
deleted file mode 100644
index 3d815b7d0..000000000
--- a/src/axolotl/integrations/lolcats/linear_llama/attention/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-"""
-Linear and linear attention + sliding window classes
-"""
-from .linear_attention import LinearAttentionState, LolcatsLinearAttention
-from .linear_window_attention_sw import (
-    LinearAttentionSlidingWindowCache,
-    LolcatsSlidingWindowAttention,
-)
-from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
-from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
-from .linear_window_attention_tk import (
-    LinearAttentionTKWindowCache,
-    LolcatsTKWindowAttention,
-)
-from .linear_window_attention_tk_gen import (
-    LinearAttentionTKWindowGenerationCache,
-    LolcatsWindowAttentionTKGen,
-)
-
-# Experimental chunk linear attentions
-from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
diff --git a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py
index 3997c134b..004056871 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/configuration_linear_llama.py
@@ -64,6 +64,12 @@ class LinearLlamaConfig(LlamaConfig):
     def __init__(self, attention_config: Optional[dict] = None, **kwargs):
         super().__init__(**kwargs)
+        # self.auto_map = {
+        #     "AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
+        #     "AutoModel": "modeling_linear_llama.LinearLlamaModel",
+        #     "AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
+        # }
+
         # Set default attention config if none provided
         self.attention_config = attention_config or {"attention_type": "softmax"}
diff --git a/src/axolotl/integrations/lolcats/linear_llama/model/feature_map.py b/src/axolotl/integrations/lolcats/linear_llama/feature_map.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/model/feature_map.py
rename to src/axolotl/integrations/lolcats/linear_llama/feature_map.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
similarity index 99%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
index 0d77f1e44..2bd6afa5d 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_attention.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_attention.py
@@ -15,8 +15,8 @@ try:
 except ImportError:
     fast_causal_dot_product = None
 
-from ..model.feature_map import init_feature_map, init_learned_kernel
-from ..model.rotary import apply_rotary_pos_emb
+from .feature_map import init_feature_map, init_learned_kernel
+from .rotary import apply_rotary_pos_emb
 from .utils import repeat_kv
 
 # -------------------
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_linear.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
similarity index 99%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_linear.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
index 27c5db46c..9ea6c9a90 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_linear.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_linear.py
@@ -23,18 +23,15 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36
 
-from ..model.rotary import apply_rotary_pos_emb
-
 # Causal linear attention dot product CUDA kernel from fast-transformers
 from .linear_attention import (
     LinearAttentionState,
     LolcatsLinearAttention,
     causal_dot_product,
 )
+from .rotary import apply_rotary_pos_emb
 
-LOG = logging.getLogger(
-    "axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
-)
+LOG = logging.getLogger(__name__)
 
 # ----------------------
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_long.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_long.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_sw_long.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_sw_long.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_gen.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_gen.py
similarity index 98%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_gen.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_gen.py
index 9ac11acbf..fb7919bde 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_gen.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_gen.py
@@ -11,9 +11,7 @@ import torch.nn.functional as F
 from .linear_attention import LinearAttentionState
 from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
 
-LOG = logging.getLogger(
-    "axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
-)
+LOG = logging.getLogger(__name__)
 
 try:
     from thunderkittens import hedgehog as tk_window_hedgehog_attention
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_long.py b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
similarity index 99%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_long.py
rename to src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
index 25df64940..79ac8f21c 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/attention/linear_window_attention_tk_long.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/linear_window_attention_tk_long.py
@@ -22,9 +22,9 @@ try:
 except ModuleNotFoundError:
     _flash_attention_forward = None  # Transformers v4.36
 
-from ..model.rotary import apply_rotary_pos_emb
 from .linear_attention import softmax_attention
 from .linear_window_attention_tk import LolcatsTKWindowAttention
+from .rotary import apply_rotary_pos_emb
 
 LOG = logging.getLogger(
     "axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
 )
diff --git a/src/axolotl/integrations/lolcats/linear_llama/model/__init__.py b/src/axolotl/integrations/lolcats/linear_llama/model/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
index a3497b55e..abe29db5a 100644
--- a/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
+++ b/src/axolotl/integrations/lolcats/linear_llama/modeling_linear_llama.py
@@ -11,7 +11,7 @@ import logging
 from functools import partial
-from typing import Any
+from typing import Any, Optional
 
 from torch import nn
 from tqdm import tqdm
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
     LlamaRotaryEmbedding,
 )
 
-from .attention import LolcatsLinearAttention
 from .configuration_linear_llama import LinearLlamaConfig
 
 LOG = logging.getLogger(__name__)
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(self, config: LinearLlamaConfig, layer_idx: int):
         super().__init__(config, layer_idx)
+
         # Replace the attention layer with our custom attention
-        self.self_attn = LolcatsLinearAttention(
-            base_attn=self.self_attn,  # type: ignore
-            layer_idx=layer_idx,
-            **config.attention_config,
+        self.self_attn = convert_llama_attention(
+            layer=self, attention_config=config.attention_config
         )
@@ -229,7 +227,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
 def convert_llama_attention(
     layer: nn.Module,
     attention_config: dict,
-    layers: list[nn.Module],  # list of layers
+    layers: Optional[list[nn.Module]] = None,  # list of layers
     train_attention: bool = False,
     remove_base_attn: bool = True,
 ):
@@ -239,7 +237,7 @@ def convert_llama_attention(
     return get_attention(**attention_config)(
         base_attn=layer.self_attn,
         layer_idx=layer.self_attn.layer_idx,  # Transformers v4.36
-        max_layer_idx=len(layers) - 1,
+        max_layer_idx=len(layers) - 1 if layers else None,
         train_attention=train_attention,
         remove_base_attn=remove_base_attn,
     )
@@ -254,39 +252,41 @@ def get_attention(attention_type: str, **kwargs):
     kwargs["attention_type"] = attention_type
 
     if attention_type == "lolcats_llama":
-        from .attention import LolcatsLinearAttention
+        from .linear_attention import LolcatsLinearAttention
 
         return partial(LolcatsLinearAttention, **kwargs)
 
     elif attention_type == "lolcats_llama_window_tk":
-        from .attention import LolcatsTKWindowAttention
+        from .linear_window_attention_tk import LolcatsTKWindowAttention
 
         return partial(LolcatsTKWindowAttention, **kwargs)
 
     elif attention_type == "lolcats_llama_window_sw":
-        from .attention import LolcatsSlidingWindowAttention
+        from .linear_window_attention_sw import LolcatsSlidingWindowAttention
 
         return partial(LolcatsSlidingWindowAttention, **kwargs)
 
     elif attention_type == "lolcats_llama_window_sw_linear":
-        from .attention import LolcatsLinearSlidingWindowAttention
+        from .linear_window_attention_sw_linear import (
+            LolcatsLinearSlidingWindowAttention,
+        )
 
         return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
 
     # Experimental chunked linear attentions below
     elif attention_type == "lolcats_long_llama_window_tk":
-        from .attention import LolcatsTKWindowLongAttention
+        from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
 
         return partial(LolcatsTKWindowLongAttention, **kwargs)
 
     elif attention_type == "lolcats_long_llama_window_sw":
-        from .attention import LolcatsSlidingWindowLongAttention
+        from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
 
         return partial(LolcatsSlidingWindowLongAttention, **kwargs)
 
     # TK generation build (requires Thunderkittens)
     elif attention_type == "lolcats_llama_window_tk_gen":
-        from .attention import LolcatsWindowAttentionTKGen
+        from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
 
         return partial(LolcatsWindowAttentionTKGen, **kwargs)
@@ -304,28 +304,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
     # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
 
     elif "lolcats_llama_window_tk_gen" in attention_type:
-        from .attention import LinearAttentionTKWindowGenerationCache
+        from .linear_window_attention_tk_gen import (
+            LinearAttentionTKWindowGenerationCache,
+        )
 
         return LinearAttentionTKWindowGenerationCache()
 
     elif "llama_window_tk" in attention_type:
-        from .attention import LinearAttentionTKWindowCache
+        from .linear_window_attention_tk import LinearAttentionTKWindowCache
 
         return LinearAttentionTKWindowCache()
 
     elif "llama_window_sw" in attention_type:
-        from .attention import LinearAttentionSlidingWindowCache
+        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
 
         return LinearAttentionSlidingWindowCache()
 
     elif "llama_window_sw_linear" in attention_type:
-        from .attention import LinearAttentionSlidingWindowCache
+        from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
 
         return LinearAttentionSlidingWindowCache()
 
     # TK generation build (requires Thunderkittens)
     elif attention_type == "lolcats_llama_window_tk_gen":
-        from .attention import LinearAttentionTKWindowGenerationCache
+        from .linear_window_attention_tk_gen import (
+            LinearAttentionTKWindowGenerationCache,
+        )
 
         return LinearAttentionTKWindowGenerationCache()
@@ -333,7 +337,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
         return past_key_values
 
     else:
-        from .attention import LinearAttentionState
+        from .linear_attention import LinearAttentionState
 
         return LinearAttentionState()
@@ -348,3 +352,10 @@ def register_linear_llama():
     AutoConfig.register("linear_llama", LinearLlamaConfig)
     AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
     AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
+
+    # register for auto classes so the custom code files are saved with checkpoints
+    LinearLlamaConfig.register_for_auto_class("AutoConfig")
+    LinearLlamaModel.register_for_auto_class("AutoModel")
+    LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
+
+    LOG.info("Registered linear_llama with transformers auto classes")
diff --git a/src/axolotl/integrations/lolcats/linear_llama/model/rotary.py b/src/axolotl/integrations/lolcats/linear_llama/rotary.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/model/rotary.py
rename to src/axolotl/integrations/lolcats/linear_llama/rotary.py
diff --git a/src/axolotl/integrations/lolcats/linear_llama/attention/utils.py b/src/axolotl/integrations/lolcats/linear_llama/utils.py
similarity index 100%
rename from src/axolotl/integrations/lolcats/linear_llama/attention/utils.py
rename to src/axolotl/integrations/lolcats/linear_llama/utils.py
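
Reviewer note: a minimal sketch (not part of the patch) of what the new auto-class registration is meant to enable when a checkpoint is saved. The output directory and the elided training/conversion step are hypothetical placeholders for whatever the lolcats plugin actually produces.

# Sketch only; "outputs/linear-llama" is a hypothetical path.
from transformers import AutoModelForCausalLM

from axolotl.integrations.lolcats.linear_llama.modeling_linear_llama import (
    register_linear_llama,
)

register_linear_llama()

# ... fine-tune or convert a LinearLlamaForCausalLM here, then:
# model.save_pretrained("outputs/linear-llama")
#
# With the register_for_auto_class() calls in place, save_pretrained() should
# also copy configuration_linear_llama.py / modeling_linear_llama.py into the
# output directory and record an auto_map entry in config.json, so the
# checkpoint can later be reloaded without importing this integration:
model = AutoModelForCausalLM.from_pretrained(
    "outputs/linear-llama", trust_remote_code=True
)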