diff --git a/src/axolotl/integrations/lolcats/LICENSE b/src/axolotl/integrations/lolcats/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/src/axolotl/integrations/lolcats/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. 
Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/src/axolotl/integrations/lolcats/linear_attention/__init__.py b/src/axolotl/integrations/lolcats/linear_attention/__init__.py new file mode 100644 index 000000000..9f94414d6 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/__init__.py @@ -0,0 +1,20 @@ +""" +Linear and linear attention + sliding window classes +""" +from .linear_attention import LinearAttentionState, LolcatsLinearAttention +from .linear_window_attention_sw import ( + LinearAttentionSlidingWindowCache, + LolcatsSlidingWindowAttention, +) +from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention +from .linear_window_attention_tk import ( + LinearAttentionTKWindowCache, + LolcatsTKWindowAttention, +) +from .linear_window_attention_tk_gen import ( + LinearAttentionTKWindowGenerationCache, + LolcatsWindowAttentionTKGen, +) + +# Experimental chunk linear attentions +from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py b/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py new file mode 100644 index 000000000..abce4dd93 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_attention.py @@ -0,0 +1,561 @@ +""" +Linear attention classes +""" + +import copy +from typing import Any, List, Optional, Tuple + +import torch +import torch.nn as nn +from transformers.cache_utils import Cache + +# Causal linear attention dot product CUDA kernel from fast-transformers +try: + from csrc import causal_dot_product as fast_causal_dot_product +except ImportError: + fast_causal_dot_product = None + +from ..model.feature_map import init_feature_map, init_learned_kernel +from ..model.rotary import apply_rotary_pos_emb, get_rotary_embeddings +from .utils import repeat_kv + +# ------------------- +# Attention functions +# ------------------- + + +def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + Causal linear attention dot product + - If available, use CUDA kernel from fast-transformers + """ + if fast_causal_dot_product is None: + kv = torch.einsum("bhlf,bhld->bhlfd", k, v) + return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2)) + return fast_causal_dot_product(q, k, v) + + +def linear_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + fp32_attention: bool = False, + eps: float = 1e-12, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """ + Compute linear attention with CUDA kernel implementation from fast-transformers + - https://github.com/idiap/fast-transformers + - Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); + v is shape (b, h, l, head_dim) + """ + dtype = q.dtype + # Causal mask already applied + y = causal_dot_product( + q.contiguous().to(dtype=torch.float32), + k.contiguous().to(dtype=torch.float32), + v.contiguous().to(dtype=torch.float32), + ) + if fp32_attention: + y = ( + y + / ( + torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps + )[..., None] + ).to(dtype=dtype) + else: + y = y.to(dtype=dtype) + k = 
k.float().cumsum(dim=2).to(dtype=dtype) + y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None] + return y, None, None + + +def softmax_attention( + q: torch.Tensor, + k: torch.Tensor, + v: Optional[torch.Tensor] = None, + causal: bool = True, + fp32_attention: bool = True, +): + """ + Standard softmax attention; only compute outputs if v is not None + -> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim) + """ + y = None + a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5) + if causal: # Apply causal mask + m, n = a.shape[-2:] + causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu( + n - m + 1 + ) + a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max) + if fp32_attention: + a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype) + else: + a = torch.softmax(a, dim=-1) + if v is not None: + y = torch.einsum("bhmn,bhnd->bhmd", a, v) + return y, a, None + + +def quadratic_attention( + q: torch.Tensor, + k: torch.Tensor, + v: Optional[torch.Tensor] = None, + causal: bool = True, + fp32_attention: bool = False, + eps: float = 1e-12, +): + """ + Compute attention with feature maps by instantiating L x L matrix of attention weights + -> Use for attention distillation + -> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim) + """ + y = None + dtype = q.dtype + if fp32_attention: + q, k = q.float(), k.float() + a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could + if causal: # Apply causal mask + m, n = a.shape[-2:] + causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu( + n - m + 1 + ) + a = a.masked_fill(causal_mask, 0) + # Normalize to compute attention + a = a / (a.sum(dim=-1, keepdim=True) + eps) + a = a.to(dtype=dtype) if fp32_attention else a + if torch.isnan(a).sum() > 0: + breakpoint() + if v is not None: + y = torch.einsum("bhmn,bhnd->bhmd", a, v) + return y, a, None + + +# --------------------- +# Attention layer class +# --------------------- + + +class LolcatsLinearAttention(nn.Module): + """ + LoLCATs attention implementation initialized from a + `LlamaAttention` or `MistralAttention` object (base_attn) + + Most of the arguments are directly tied to argparse args + - For now we don't support padding. 
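+
+    A hedged usage sketch (illustrative only): `base_attn` is an existing
+    `LlamaAttention`/`MistralAttention` module, and the feature-map /
+    learned-kernel names below are placeholders for whatever
+    `init_feature_map` / `init_learned_kernel` accept in this repo.
+
+        attn = LolcatsLinearAttention(
+            base_attn=decoder_layer.self_attn,        # hypothetical module handle
+            feature_map="<feature_map_name>",         # placeholder
+            feature_map_kwargs={},                    # placeholder
+            learned_kernel="<learned_kernel_name>",   # placeholder
+            learned_kernel_kwargs={},                 # placeholder
+            layer_idx=0,
+        )
+        # Returns (attn_output, attn_weights, past_key_value); attn_weights is a
+        # (predicted, ground-truth) pair only when train_attention=True and
+        # output_attentions=True (attention transfer), otherwise None.
+        y, attn_weights, past_key_value = attn(hidden_states, position_ids=position_ids)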
+ """ + + def __init__( + self, + base_attn: nn.Module, # like LlamaAttention + feature_map: str, + feature_map_kwargs: dict, + layer_idx: Optional[int] = None, + max_layer_idx: Optional[int] = None, + learned_kernel: Optional[str] = None, + learned_kernel_kwargs: Optional[dict] = None, + tie_qk_kernels: Optional[bool] = False, + rotary_config: Optional[dict] = None, + train_attention: Optional[bool] = False, + remove_base_attn: bool = True, + attention_type: Optional[str] = "lolcats_llama", + mask_value: int = 0, + eps: float = 1e-12, + fp32_attention: bool = False, + track_state_grads: bool = False, + rank: Optional[int] = 0, + **kwargs, + ) -> None: + super().__init__() + self.base_config = getattr(base_attn, "config", None) + if self.base_config is not None: + self.base_config = self.base_config.to_dict() + self.attention_type = attention_type + self.mask_value = mask_value + self.eps = eps + self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx + self.max_layer_idx = max_layer_idx + self.tie_qk_kernels = tie_qk_kernels + self.train_attention = train_attention + self.base_inference = False + self.fp32_attention = fp32_attention + self.track_state_grads = track_state_grads + if rank == 0: # multi-gpu + if fp32_attention and layer_idx == 0: + print(f"-> fp32_attention is {fp32_attention}") + if layer_idx == 0 and feature_map_kwargs is not None: + for k, v in feature_map_kwargs.items(): + print(f"-> {k}: {v}") + if layer_idx == 0 and learned_kernel_kwargs is not None: + for k, v in learned_kernel_kwargs.items(): + print(f"-> {k}: {v}") + + self.remove_base_attn = remove_base_attn + + # Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0) + self.rotary_config = rotary_config + # if isinstance(self.rotary_config, DictDefault): + # self.rotary_config = OmegaConf.to_container(self.rotary_config) + + self.rotary_emb = None + if self.base_config is not None and self.rotary_config is None: + self.rotary_emb = base_attn.rotary_emb + + self.init_weights_(base_attn, remove_base_attn) + self.init_feature_map_( + feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs + ) + + def init_feature_map_( + self, + feature_map: str, + feature_map_kwargs: dict, + learned_kernel: Optional[str] = None, + learned_kernel_kwargs: Optional[dict] = None, + ): + """ + Initialize MLP-based feature map + """ + self.fmap_gqa = False # Turn True if specified below + if learned_kernel is not None and learned_kernel_kwargs is not None: + # Ensure dict + learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()} + learned_kernel_kwargs["num_heads"] = self.num_heads + learned_kernel_kwargs["head_dim"] = self.head_dim + learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype + learned_kernel_kwargs["device"] = self.q_proj.weight.device + # Create MLP + mlp_learned_kernel = init_learned_kernel( + learned_kernel, **learned_kernel_kwargs + ) + # Add "activation"; see src.models.feature_map.py + self.feature_map_q = init_feature_map( + name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs + ) + if self.tie_qk_kernels: # tie mlp weights for query and key feature maps + self.feature_map_k = self.feature_map_q + else: + self.feature_map_k = copy.deepcopy(self.feature_map_q) + + def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True): + """ + Initialize module layers, weights, positional dependencies, etc. 
+ from original softmax attention layer (base_attn) + """ + # Make other attributes accessible + self.attention_dropout = 0 # We don't use dropout + self.hidden_size = base_attn.hidden_size + self.num_heads = base_attn.num_heads + self.head_dim = base_attn.head_dim + self.num_key_value_heads = base_attn.num_key_value_heads + self.num_key_value_groups = base_attn.num_key_value_groups + + self.q_shape = [self.num_heads, self.head_dim] + self.k_shape = [self.num_key_value_heads, self.head_dim] + self.v_shape = [self.num_key_value_heads, self.head_dim] + device = base_attn.q_proj.weight.device + # Rotary embeddings + if self.rotary_emb is None: + self.max_position_embeddings = base_attn.max_position_embeddings + scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0) + if self.rotary_config is None: + self.rotary_emb = get_rotary_embeddings( + rope_scaling_type=None, + head_dim=self.head_dim, + max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings, + rope_theta=base_attn.rotary_emb.base, + rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor, + device=device, + ) + else: + if "device" not in self.rotary_config: + self.rotary_config["device"] = device + self.rotary_emb = get_rotary_embeddings(**self.rotary_config) + + # Copy original model projection layers + self.q_proj = base_attn.q_proj + self.k_proj = base_attn.k_proj + self.v_proj = base_attn.v_proj + self.o_proj = base_attn.o_proj + try: # If wanting to use FA2 for ground-truth inference + self._flash_attn_uses_top_left_mask = ( + base_attn._flash_attn_uses_top_left_mask + ) + except AttributeError: + pass + + if self.remove_base_attn or remove_base_attn: + del base_attn # We don't need to keep these around + else: + self.base_attn = base_attn # For some training runs helpful to just call + + def process_qkv( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Any] = None, + ): # "legacy" cache approach + """ + Compute queries, keys, and values + """ + b, l, _ = hidden_states.size() + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + kv_seq_len = k.shape[-2] + + # Shape is (batch_size, seq_len, num_heads, head_dim) + q = q.view(b, l, *self.q_shape).transpose(1, 2) + k = k.view(b, l, *self.k_shape).transpose(1, 2) + v = v.view(b, l, *self.v_shape).transpose(1, 2) + + if ( + past_key_value is not None + ): # and k.shape[2] > q.shape[2]: # e.g., when generating + past_key_value.window_size = getattr( + self, "decode_window_size", None + ) # self.decode_window_size + if isinstance( + past_key_value, Cache + ): # In Transformers v4.36+ this is a DynamicCache object + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) + else: + kv_seq_len += past_key_value[0].shape[-2] + + # Apply rotary embeddings and repeat for GQA + if position_ids is not None and kv_seq_len <= position_ids[0, -1]: + kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids + + if self.rotary_emb is None: + raise ValueError("Rotary embeddings not initialized") + + try: # As in Transformers v4.36 + cos, sin = self.rotary_emb(k, seq_len=kv_seq_len) + q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids) + except TypeError: # As in Transformers v4.39+ + cos, sin = self.rotary_emb(v, position_ids) + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + k = repeat_kv(k, self.num_key_value_groups) + 
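+        # (Descriptive note) repeat_kv tiles each of the num_key_value_heads
+        # KV heads num_key_value_groups times so that k and v line up with the
+        # query heads for grouped-query attention, e.g. 8 KV heads x 4 groups
+        # -> 32 heads matching q.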
v = repeat_kv(v, self.num_key_value_groups) + return q, k, v, kv_seq_len + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Any] = None, # "legacy" cache approach + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36) + - Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + + if self.base_inference: + with torch.no_grad(): + # 1. Compute "ground-truth" attention output and weights + y_true, _, _ = softmax_attention(q, k, v, causal=True) + y_true = ( + y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + ) + y_true = self.o_proj(y_true) + attn_weights = (None, None) + + elif self.train_attention: # Distilling / learning attentions + # Note for now we assume no padding when distilling; attention masks only enforce causality + assert ( + output_attentions is True + ), f"When training feature maps, output_attentions should be True but is {output_attentions}" + with torch.no_grad(): + # 1. Compute "ground-truth" attention output and weights + _y_true, attn_true, _ = softmax_attention(q, k, v, causal=True) + y_true = ( + _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + ) + y_true = self.o_proj(y_true) + + # 2. Compute "predicted" attention (just weights) + q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k) + y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True) + attn_weights = ( # type: ignore + (attn_pred, attn_true), + (y_pred, _y_true), + ) # Save both attention weights so we can supervise. 
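+        # (Descriptive note) The branch below computes causal linear attention
+        # with the learned feature maps. Per position t it is equivalent to the
+        # recurrence
+        #     S_t = S_{t-1} + phi(k_t) v_t^T        (kv_state, shape b, h, f, d)
+        #     z_t = z_{t-1} + phi(k_t)              (k_state,  shape b, h, 1, f)
+        #     y_t = phi(q_t)^T S_t / (phi(q_t)^T z_t + eps)
+        # which is what `linear_attention` evaluates in parallel over the whole
+        # sequence, and what `LinearAttentionState` accumulates for decoding.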
+ + else: # Finetuning + q, k = self.feature_map_q(q), self.feature_map_k(k) + # Apply prefill mask + if attention_mask is not None and q.shape[2] > 1: + if len(attention_mask.shape) == 4: + lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][ + ..., None + ] # b, 1, k_len, 1 + else: + lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 + k = k.masked_fill(~lin_attn_mask, 0) + + if past_key_value is not None: # Initialize states + if len(past_key_value.kv_states) == self.layer_idx: + b, h, _, f = k.shape + past_key_value.kv_states.append( + torch.zeros( + b, h, f, self.head_dim, dtype=q.dtype, device=q.device + ) + ) + past_key_value.k_states.append( + torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device) + ) + # Generating + if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None: + assert use_cache is True + kv_state, k_state = past_key_value.update( + k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention + ) + if self.fp32_attention: + q = q.float() + y_true = ( + torch.einsum("bhlf,bhfd->bhld", q, kv_state.float()) + / torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[ + ..., None + ] + ).to(dtype=k.dtype) + else: + y_true = ( + torch.einsum("bhlf,bhfd->bhld", q, kv_state) + / torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None] + ) + else: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + y_true, _, _ = linear_attention( + q, k, v, self.fp32_attention, self.eps + ) # Ordinarily the states are ignored + past_key_value.update( + k.detach(), + v.detach(), + self.layer_idx, + accumulate_in_fp32=self.fp32_attention, + ) + # doing some unnecessary recomputation here + else: + y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps) + + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + attn_weights = None + + return y_true, attn_weights, past_key_value + + +class LinearAttentionState(Cache): + """ + Handle the KV and K states for linear attention + - Adopts HF Transformers `past_key_values` convention + - Inherits from `Cache` class + - Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + + def __init__(self) -> None: + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """ + Returns the sequence length of the cached states. A layer index can be optionally passed. + """ + if layer_idx is None: + raise ValueError("Layer index must not be None") + + if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states + self._seen_tokens_by_layer.append(0) + return self._seen_tokens_by_layer[layer_idx] + + def get_max_length(self) -> Optional[int]: + """ + Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length. 
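+        (The linear-attention "cache" stores fixed-size kv/k summary states per
+        layer rather than growing key/value tensors, so nothing is ever evicted.)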
+ """ + return None + + def get_usable_length( + self, new_seq_length: int, layer_idx: Optional[int] = 0 + ) -> int: + """Given the sequence length of the new inputs, returns the usable length of the cache.""" + # Cache without size limit -> all cache is usable + # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache + # length, we will need to evict part of the cache (and thus not all cache is usable) + max_length = self.get_max_length() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: Optional[int] = None, + cache_kwargs: Optional[Any] = None, + accumulate_in_fp32: bool = True, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if layer_idx is None: + raise ValueError("Layer index must not be None") + + with torch.no_grad(): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + dtype = key_states.dtype + if accumulate_in_fp32: + key_states, value_states = key_states.float(), value_states.float() + + kv_state = torch.einsum( + "bhlf,bhld->bhfd", key_states, value_states + ).detach() + k_state = key_states.sum( + dim=-2, keepdim=True + ).detach() # b, h, 1, f; note the 1 + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + print( + "if len(self.k_states) <= layer_idx: # Initializing kv and k states" + ) + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + else: + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to( + dtype + ) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to( + dtype + ) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def to_legacy_cache(self): + """Hack, but just return self""" + return self + + def reorder_cache(self, beam_idx: torch.LongTensor): + """ + Reorders the cache for beam search, given the selected beam indices. 
+ -> Copied from transformers/src/transformers/cache_utils.py + """ + raise NotImplementedError( + "Reordering cache not implemented for LinearAttentionState" + ) diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw.py new file mode 100644 index 000000000..a75cc3cc3 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw.py @@ -0,0 +1,460 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using "standard" sliding windows +- Didactically computes outputs with n^2 attention weights for now +- Copied + adapted from linear_window_attention_tk.py for single-file reference + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" + +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache + +from .linear_attention import ( + LinearAttentionState, + LolcatsLinearAttention, + softmax_attention, +) + + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks( + window_size: int, q_len: int, k_len: int, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril( + k_len - q_len + ) + linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril( + k_len - q_len - window_size + ) + window_mask = causal_mask - linear_mask + # Return softmax mask (window), linear attention mask + # -> shapes broadcast over (b, h, q_len, k_len) + return window_mask[None, None, ...], linear_mask[None, None, ...] + + +def hybrid_attention_quadratic( + q: torch.Tensor, + k: torch.Tensor, + f_q: torch.Tensor, + f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: Optional[torch.Tensor] = None, + k_state: Optional[torch.Tensor] = None, + eps: float = 1e-12, + mask_value: float = -1e8, +): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks( + window_size, q.shape[-2], k.shape[-2], q.device + ) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. 
Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float()) + if ( + kv_state is not None and k_state is not None + ): # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum( + "bhld,bhdf->bhlf", f_q.float(), kv_state.float() + ) + sum_ln += ( + linear_factor + * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None] + ) + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsSlidingWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + + def __init__( + self, + window_size: int = 64, + decode_window_size: Optional[int] = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs, + ): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1} + super().__init__(**kwargs) + self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw' + # Determine how we compute attentions + self.quadratic_attention = hybrid_attention_quadratic + self.attention_type = kwargs[ + "attention_type" + ] # 'hedgehog_long_llama_window_sw' + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor + * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + else: + self.register_buffer( + "window_factors", + init_window_factor + * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype), + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + f_q, f_k = self.feature_map_q(q), self.feature_map_k( + k + ) # Have to do after repeat for grouped-query attn if we use same fmap + + if self.train_attention: + # 1. Compute "ground-truth" attention output and weights + with torch.no_grad(): + _y_true, a_true = softmax_attention(q, k, v)[:2] + y_true = ( + _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + ) + y_true = self.o_proj(y_true) + + # 2. 
Compute "predicted" attention outputs + # compute attn weights under sliding window + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + attn_weights = ((a_pred, a_true), (y_pred, _y_true)) + else: + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if ( + f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training + ): # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding( + k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype + ) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + + # Softmax attention terms + a_sm = torch.einsum( + "bhmd,bhnd->bhmn", q.float(), k_cache.float() + ) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = torch.einsum( + "bhmn,bhnd->bhmd", a_sm, v_cache.float() + ) + linear_factors * torch.einsum( + "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float() + ) + sum_ln = ( + linear_factors + * torch.einsum( + "bhlf,bhnf->bhl", f_q.float(), f_k_state.float() + )[..., None] + ) + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, _ = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state, + ) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update( + k, + v, + self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True, + ) + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, attn_weights, past_key_value + + +class LinearAttentionSlidingWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for 
sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: Optional[int] = None, + cache_kwargs: Optional[Any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: Optional[torch.Tensor] = None, # should not be None + grad_enabled: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + if fmap_key_states is None: + raise ValueError("fmap_key_states must not be None") + + if layer_idx is None: + raise ValueError("Layer index must not be None") + + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, : -self.window_size], + value_states[:, :, : -self.window_size], + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, -self.window_size :], + value_states[:, :, -self.window_size :], + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, : -self.window_size].sum( + dim=-2, keepdim=True + ) + k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum( + dim=-2, keepdim=True + ) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size :, :]) + self.v_cache.append( + value_states[:, :, -self.window_size :, :].to(dtype) + ) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to( + dtype + ) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to( + dtype + ) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = ( + self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state + ).to(dtype) + decode_k_state = ( + self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state + ).to(dtype) + self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding( + self, + keys: torch.Tensor, + values: torch.Tensor, + layer_idx: int, + feature_map_k: Callable, + dtype: torch.dtype, + ): + """ + Update the decoding KV and K states, 
and KV cache, during decodeing + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) + else: + # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size + # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache + # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device) + # else: + # f_k_state = feature_map_k(k_cache[:, :, :1, :]) + # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum( + "bhlf,bhld->bhfd", k_state.float(), v_state.float() + ).to( + dtype + ) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat( + [k_cache[:, :, 1:, :], keys], dim=-2 + ) + self.v_cache[layer_idx] = torch.cat( + [v_cache[:, :, 1:, :], values], dim=-2 + ) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return ( + self.k_cache[layer_idx], + self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], + self.decode_k_states[layer_idx], + ) diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_linear.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_linear.py new file mode 100644 index 000000000..27c5db46c --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_linear.py @@ -0,0 +1,687 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using "standard" sliding windows +- Didactically computes outputs with n^2 attention weights for now +- Copied + adapted from linear_window_attention_tk.py for single-file reference + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" + +import logging +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache + +try: + from transformers.modeling_flash_attention_utils import _flash_attention_forward +except ModuleNotFoundError: + _flash_attention_forward = None # Transformers v4.36 + +from ..model.rotary import apply_rotary_pos_emb + +# Causal linear attention dot product CUDA kernel from fast-transformers +from .linear_attention import ( + LinearAttentionState, + LolcatsLinearAttention, + causal_dot_product, +) + +LOG = logging.getLogger( + "axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long" +) + + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks( + window_size: int, q_len: int, k_len: int, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril( + max(k_len - q_len, 0) + ) + linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril( + max(k_len - q_len, 0) - window_size + ) + 
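+    # (Descriptive note) Worked example with q_len = k_len = 4, window_size = 2:
+    #   causal_mask        linear_mask        window_mask = causal - linear
+    #   1 0 0 0            0 0 0 0            1 0 0 0
+    #   1 1 0 0            0 0 0 0            1 1 0 0
+    #   1 1 1 0            1 0 0 0            0 1 1 0
+    #   1 1 1 1            1 1 0 0            0 0 1 1
+    # i.e. each query gets softmax attention over the last `window_size`
+    # positions (window_mask) and linear attention over everything earlier
+    # (linear_mask).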
window_mask = causal_mask - linear_mask + # Return softmax mask (window), linear attention mask + # -> shapes broadcast over (b, h, q_len, k_len) + return window_mask[None, None, ...], linear_mask[None, None, ...] + + +def hybrid_attention_quadratic( + q: torch.Tensor, + k: torch.Tensor, + f_q: torch.Tensor, + f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: Optional[torch.Tensor] = None, + k_state: Optional[torch.Tensor] = None, + eps: float = 1e-12, + mask_value: float = -1e8, +): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks( + window_size, q.shape[-2], k.shape[-2], q.device + ) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float()) + if ( + kv_state is not None and k_state is not None + ): # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum( + "bhld,bhdf->bhlf", f_q.float(), kv_state.float() + ) + sum_ln += ( + linear_factor + * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None] + ) + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# ------------------------------ +# Hybrid window attention linear +# ------------------------------ +def under_window_linear_attention( + f_q: torch.Tensor, + f_k: torch.Tensor, + v: torch.Tensor, + window_size: int, + linear_factor: torch.Tensor, + eps: float = 1e-12, +): + """Compute hybrid window attention dot product with linear complexity in q_len""" + dtype = f_q.dtype + w = window_size + f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :] + v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :] + qkv = linear_factor * causal_dot_product( + f_q.contiguous().to(dtype=torch.float32), + f_k.contiguous().to(dtype=torch.float32), + v.contiguous().to(dtype=torch.float32), + ).to(dtype=dtype) + sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype) + sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None] + sum_qk[sum_qk == 0] += eps + return qkv, sum_qk + + +def sliding_window_softmax_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + window_size: int, + window_factor: torch.Tensor, + mask_value: float = -1e8, +): + """ + Compute sliding window softmax attention without materializing + O(seq_len^2) attention weights + """ + d = q.shape[-1] + # Compute windows for keys + window_kwargs = {"dimension": 2, "size": window_size, "step": 1} + k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) + v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs) + + # Compute windowed_softmax(qk); causal in its construction + a_sm = torch.einsum("bhld,bhldw->bhlw", q, 
k) * (d**-0.5) + a_sm[a_sm == 0] = -torch.finfo( + q.dtype + ).max # heuristic for zeroing out padding above + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm + # return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v) + + +def hybrid_attention_linear( + q: torch.Tensor, + k: torch.Tensor, + f_q: torch.Tensor, + f_k: torch.Tensor, + v: torch.Tensor, + window_factor: Optional[torch.Tensor] = None, + linear_factor: Optional[torch.Tensor] = None, + window_size: int = 64, + kv_state: Optional[torch.Tensor] = None, + k_state: Optional[torch.Tensor] = None, + eps: float = 1e-12, + mask_value: float = -1e8, +): + """ + Alternative hybrid attention combining sliding window and linear attentions + -> Uses O(n) memory if n is sequence length by padding and unfolding windows + """ + # window_kwargs = {"dimension": 2, "size": window_size, "step": 1} + if window_factor is None: + raise ValueError("window_factor must be provided") + + if linear_factor is None: + raise ValueError("linear_factor must be provided") + + # 1. Sliding window (softmax attention) + with torch.no_grad(): + qkv_sm, sum_qk_sm = sliding_window_softmax_attention( + q, k, v, window_size, window_factor, mask_value + ) + + # 2. Under window (linear attention) + qkv_ln, sum_qk_ln = under_window_linear_attention( + f_q, f_k, v, window_size, linear_factor, eps + ) + + # 3. Combine + y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln) + return y, None + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + + def __init__( + self, + window_size: int = 64, + decode_window_size: Optional[int] = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs, + ): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1} + super().__init__(**kwargs) + # Determine how we compute attentions + self.linear_attention = hybrid_attention_linear + self.attention_type = "lolcats_llama_window_sw" + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor + * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + else: + self.register_buffer( + "window_factors", + init_window_factor + * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype), + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with 
their pretrained models + """ + b, l, _ = hidden_states.size() + + if self.train_attention and self.base_inference: + with torch.no_grad(): + _y_true = flash_attention_2( + self, # self.base_attn, + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + use_cache=False, + )[0] + # _y_true.shape is (batch_size, seq_len, num_heads, head_dim) + y_true = _y_true.reshape(b, l, -1).contiguous() + y_true = self.o_proj(y_true) + # layer_io = (hidden_states, _y_true) # hack + layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack + return y_true, layer_io, None + + else: + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + f_q, f_k = self.feature_map_q(q), self.feature_map_k( + k + ) # Have to do after repeat for grouped-query attn if we use same fmap + + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, a_pred = self.linear_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if ( + f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training + ): # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding( + k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype + ) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + + # Softmax attention terms + a_sm = torch.einsum( + "bhmd,bhnd->bhmn", q.float(), k_cache.float() + ) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = torch.einsum( + "bhmn,bhnd->bhmd", a_sm, v_cache.float() + ) + linear_factors * torch.einsum( + "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float() + ) + sum_ln = ( + linear_factors + * torch.einsum( + "bhlf,bhnf->bhl", f_q.float(), f_k_state.float() + )[..., None] + ) + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, _ = self.linear_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state, + ) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update( + k, + v, + self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True, + ) + # Concatenate heads and apply output projection + _y_true = y_true.transpose(1, 2).contiguous() + y_true = self.o_proj(_y_true.view(b, l, self.hidden_size)) + + if self.train_attention: + attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d) + return y_true, attn_weights, 
past_key_value + + +class LinearAttentionSlidingWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: Optional[int] = None, + cache_kwargs: Optional[Any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: Optional[torch.Tensor] = None, # should not be None + grad_enabled: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + if fmap_key_states is None: + raise ValueError("fmap_key_states must not be None") + + if layer_idx is None: + raise ValueError("Layer index must not be None") + + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, : -self.window_size], + value_states[:, :, : -self.window_size], + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, -self.window_size :], + value_states[:, :, -self.window_size :], + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, : -self.window_size].sum( + dim=-2, keepdim=True + ) + k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum( + dim=-2, keepdim=True + ) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size :, :]) + self.v_cache.append( + value_states[:, :, -self.window_size :, :].to(dtype) + ) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to( + dtype + ) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to( + dtype + ) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = ( + self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state + ).to(dtype) + decode_k_state = ( + self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state + ).to(dtype) + 
self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding( + self, + keys: torch.Tensor, + values: torch.Tensor, + layer_idx: int, + feature_map_k: Callable, + dtype: torch.dtype, + ): + """ + Update the decoding KV and K states, and KV cache, during decoding + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) + else: + # MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size + # if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache + # f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device) + # else: + # f_k_state = feature_map_k(k_cache[:, :, :1, :]) + # -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum( + "bhlf,bhld->bhfd", k_state.float(), v_state.float() + ).to( + dtype + ) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat( + [k_cache[:, :, 1:, :], keys], dim=-2 + ) + self.v_cache[layer_idx] = torch.cat( + [v_cache[:, :, 1:, :], values], dim=-2 + ) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return ( + self.k_cache[layer_idx], + self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], + self.decode_k_states[layer_idx], + ) + + +# ----------------- +# Flash Attention 2 +# ----------------- + + +def flash_attention_2( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +): + """ + Wrapper for LlamaFlashAttention2 + Copied and modified from HF Transformers v4.36 and v4.43 implementations + - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402 + - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456 + """ + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + try: # As in 
Transformers v4.36 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + except Exception: # As in Transformers v4.39 + cos, sin = self.rotary_emb(key_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + LOG.debug( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self, "_flash_attention_forward", False): + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=True, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=0, # dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + return attn_output, past_key_value diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_long.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_long.py new file mode 100644 index 000000000..cd9176222 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_sw_long.py @@ -0,0 +1,24 @@ +""" +LoLCATs attention combining sliding window and linear attentions +- Using standard sliding window arrangement +- Training over long sequences with fixed memory with recurrent view +- During attention transfer, use Flash Attention to compute softmax attention outputs + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" +from .linear_window_attention_sw import hybrid_attention_quadratic +from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention + + +class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + + def __init__(self, remove_base_attn=True, **kwargs): + # keep self.base_attn for Flash Attention inference + super().__init__(remove_base_attn=True, **kwargs) + self.quadratic_attention = hybrid_attention_quadratic diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk.py new file mode 100644 index 000000000..3c22bb5be --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk.py @@ -0,0 +1,466 @@ +""" +Subquadratic attention combining sliding window and linear attentions +- Using the TK "terracing" arrangement + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" + +import math +from typing import Any, Callable, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache + +from .linear_attention import ( + LinearAttentionState, + LolcatsLinearAttention, + softmax_attention, +) + + +# ---------------------- +# Sliding window helpers +# ---------------------- +def get_masks( + window_size: int, q_len: int, k_len: int, device: torch.device +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Return masks for softmax and linear attention terms + -> 1 is include, 0 is ignore + """ + win_len = window_size + m = math.ceil(max(q_len, k_len) / window_size) + # Creates an n x n mask where n = window_size^2 + mask = torch.block_diag( + *[ + torch.ones( + (win_len, win_len), + ) + ] + * m + ) + mask += torch.roll(mask, -win_len, -1) # this 
adds the terracing + if mask.shape[0] > q_len: + mask = mask[-q_len:] + if mask.shape[1] > k_len: + mask = mask[:, -k_len:] + # Return softmax mask (window), linear attention mask + mask = mask[None, None, ...] # b, h, q_len, k_len + return ( + torch.tril(mask).to(device=device, dtype=torch.int), + torch.tril(1 - mask).to(device=device, dtype=torch.int), + ) + + +def hybrid_attention_quadratic( + q: torch.Tensor, + k: torch.Tensor, + f_q: torch.Tensor, + f_k: torch.Tensor, + v: torch.Tensor, + window_factor: torch.Tensor, + linear_factor: torch.Tensor, + window_size: int, + kv_state: Optional[torch.Tensor] = None, + k_state: Optional[torch.Tensor] = None, + eps: float = 1e-12, + mask_value: float = -1e8, +): + """ + Hybrid attention combining sliding window and linear attentions + """ + + mask_window, mask_linear = get_masks( + window_size, q.shape[-2], k.shape[-2], q.device + ) + + # 1. Sliding window (softmax attention) + a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5) + a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value) + # torch.softmax(a_sm, dim=-1), but we account for the max when combining + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factor * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # 2. Under window (linear attention) + a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float()) + a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0) + sum_ln = a_ln.sum(dim=-1, keepdim=True) + + # 3. Combine + a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights + # Allow outputs to also depend on prior kv_state and k_state + y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float()) + if ( + kv_state is not None and k_state is not None + ): # Combine with prior kv_state and k_state + y += linear_factor * torch.einsum( + "bhld,bhdf->bhlf", f_q.float(), kv_state.float() + ) + sum_ln += ( + linear_factor + * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None] + ) + y = (y / (sum_sm + sum_ln)).to(q.dtype) + return y, a # attention weights only for the last chunk + + +# --------------------- +# Attention layer class +# --------------------- +class LolcatsTKWindowAttention(LolcatsLinearAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + + def __init__( + self, + window_size: int = 64, + decode_window_size: Optional[int] = None, + affine_attention_factors: bool = False, + init_window_factor: float = 0, + train_window_factor: bool = True, + state_grad_enabled: bool = False, + **kwargs, + ): + self.window_size = window_size + self.decode_window_size = ( + decode_window_size if decode_window_size is not None else window_size + ) + self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1} + super().__init__(**kwargs) + self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk' + # Determine how we compute attentions + self.quadratic_attention = hybrid_attention_quadratic + self.attention_type = kwargs[ + "attention_type" + ] # 'hedgehog_long_llama_window_tk' + # Learnable factor for combining attentions + self.affine_attention_factors = affine_attention_factors + device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype + if train_window_factor: + self.window_factors = nn.Parameter( + init_window_factor + * torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype) + ) + else: + self.register_buffer( + "window_factors", + init_window_factor + * torch.ones(1, 
self.num_heads, 1, 1, device=device, dtype=dtype), + ) + # Whether we use original flash attention 2 inference (use during attention transfer) + self.base_inference = False + self.state_grad_enabled = state_grad_enabled + self.window_factor = self.window_factors # legacy naming support + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + f_q, f_k = self.feature_map_q(q), self.feature_map_k( + k + ) # Have to do after repeat for grouped-query attn if we use same fmap + + if self.train_attention: + # 1. Compute "ground-truth" attention output and weights + with torch.no_grad(): + _y_true, a_true = softmax_attention(q, k, v)[:2] + y_true = ( + _y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + ) + y_true = self.o_proj(y_true) + + # 2. Compute "predicted" attention outputs + # compute attn weights under sliding window + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + attn_weights = ((a_pred, a_true), (y_pred, _y_true)) + else: + attn_weights = None + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + attn_weights = a_pred + else: + past_key_value.window_size = self.decode_window_size + if ( + f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training + ): # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding( + k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype + ) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + + # Softmax attention terms + a_sm = torch.einsum( + "bhmd,bhnd->bhmn", q.float(), k_cache.float() + ) * (k.shape[-1] ** -0.5) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = torch.einsum( + "bhmn,bhnd->bhmd", a_sm, v_cache.float() + ) + linear_factors * torch.einsum( + "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float() + ) + sum_ln = ( + linear_factors + * torch.einsum( + "bhld,bhnd->bhl", f_q.float(), f_k_state.float() + )[..., None] + ) + y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + 
window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_true, _ = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state, + ) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update( + k, + v, + self.layer_idx, + fmap_key_states=f_k, + accumulate_in_fp32=True, + ) + # Concatenate heads and apply output projection + y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, attn_weights, past_key_value + + +class LinearAttentionTKWindowCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a "KV state" and "K state" + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.kv_states: List[torch.Tensor] = [] + self.k_states: List[torch.Tensor] = [] + + # Account for sliding windows + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + self.window_size = window_size + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: Optional[int] = None, + cache_kwargs: Optional[Any] = None, + accumulate_in_fp32: bool = False, + fmap_key_states: Optional[torch.Tensor] = None, # should not be None + grad_enabled: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Update KV, K states; and KV cache during training + - For decoding, use `self.decode_kv_states` to keep track of KV states + up to sliding window terms + - For (chunked) training, use `self.kv_states` to keep track of KV states + up to end of sequence + - Likewise for `self.decode_k_states` and `self.k_states` + """ + if fmap_key_states is None: + raise ValueError("fmap_key_states should not be None") + + if layer_idx is None: + raise ValueError("layer_idx should not be None") + + with torch.set_grad_enabled(grad_enabled): + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + dtype = key_states.dtype + if accumulate_in_fp32: + # key_states = key_states.float() + fmap_key_states = fmap_key_states.float() + value_states = value_states.float() + + # Decoding KV state (KV terms up to last window_size) + decode_kv_state = torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, : -self.window_size], + value_states[:, :, : -self.window_size], + ) + # KV state + kv_state = decode_kv_state + torch.einsum( + "bhlf,bhld->bhfd", + fmap_key_states[:, :, -self.window_size :], + value_states[:, :, -self.window_size :], + ) + # shape is b, h, 1, f; note the 1 + decode_k_state = fmap_key_states[:, :, : -self.window_size].sum( + dim=-2, keepdim=True + ) + k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum( + dim=-2, keepdim=True + ) + + # Update the cache + if len(self.k_states) <= layer_idx: # Initializing kv and k states + self.kv_states.append(kv_state.to(dtype)) + self.k_states.append(k_state.to(dtype)) + + self.decode_kv_states.append(decode_kv_state.to(dtype)) + 
self.decode_k_states.append(decode_k_state.to(dtype)) + + self.k_cache.append(key_states[:, :, -self.window_size :, :]) + self.v_cache.append( + value_states[:, :, -self.window_size :, :].to(dtype) + ) + # self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2]) + else: + # Update kv and k states recurrently + kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to( + dtype + ) + k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to( + dtype + ) + self.kv_states[layer_idx] = kv_state + self.k_states[layer_idx] = k_state + + decode_kv_state = ( + self.decode_kv_states[layer_idx].to(kv_state.dtype) + + decode_kv_state + ).to(dtype) + decode_k_state = ( + self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state + ).to(dtype) + self.decode_kv_states[layer_idx] = decode_kv_state + self.decode_k_states[layer_idx] = decode_k_state + + self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :] + self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :] + self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2] + + return self.kv_states[layer_idx], self.k_states[layer_idx] + + def update_for_decoding( + self, + keys: torch.Tensor, + values: torch.Tensor, + layer_idx: int, + feature_map_k: Callable, + dtype: torch.dtype, + ): + """ + Update the decoding KV and K states, and KV cache, during decoding + """ + with torch.no_grad(): + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + + if k_cache.shape[-2] < self.window_size: # build window-size cache + self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2) + else: + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum( + "bhlf,bhld->bhfd", k_state.float(), v_state.float() + ).to( + dtype + ) # b, h, f, d + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat( + [k_cache[:, :, 1:, :], keys], dim=-2 + ) + self.v_cache[layer_idx] = torch.cat( + [v_cache[:, :, 1:, :], values], dim=-2 + ) + + if layer_idx == 0: + self._seen_tokens += keys.shape[-2] + self._seen_tokens_by_layer[layer_idx] += keys.shape[-2] + return ( + self.k_cache[layer_idx], + self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], + self.decode_k_states[layer_idx], + ) diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_gen.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_gen.py new file mode 100644 index 000000000..9ac11acbf --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_gen.py @@ -0,0 +1,221 @@ +""" +LoLCATs + ThunderKittens linear attention + sliding window for generation +""" + +import logging +from typing import Any, Callable, List, Optional + +import torch +import torch.nn.functional as F + +from .linear_attention import LinearAttentionState +from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention + +LOG = logging.getLogger( + "axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen" +) + +try: + from thunderkittens import hedgehog as tk_window_hedgehog_attention + + LOG.debug("Successfully imported ThunderKittens for TK window attention") +except ImportError: + LOG.debug("Failed to import ThunderKittens for TK window attention") + + +class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention): + def __init__(self, *args, 
window_size: int = 64, **kwargs): + super().__init__(*args, **kwargs) + self.train_attention = False + self.base_inference = False + self.window_size = 64 # hard-coded support for TK kernel + self.decode_window_size = 64 + + b, h, l, d = 1, 32, 8192, 128 + self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda") + self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda") + self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Any] = None, # “legacy” cache approach + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + assert ( + past_key_value is not None + ), "past_key_value must be provided for generation" + assert ( + self.train_attention is False + ), "train_attention is not supported for generation" + assert ( + self.base_inference is False + ), "base_inference is not supported for generation" + assert use_cache is True, "use_cache must be True for generation" + past_key_value.window_size = self.decode_window_size + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill + f_q = self.feature_map_q(q) + _kv = past_key_value.update_for_decoding( + k, v, self.layer_idx, self.feature_map_k + ) + k_cache, v_cache, kv_state, k_state = _kv + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + + # Softmax attention terms + a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * ( + k.shape[-1] ** -0.5 + ) + a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + # Combine with linear attention terms + y_true = torch.einsum( + "bhmn,bhnd->bhmd", a_sm, v_cache.float() + ) + linear_factors * torch.einsum( + "bhld,bhdf->bhlf", f_q.float(), kv_state.float() + ) + sum_ln = ( + linear_factors + * torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[ + ..., None + ] + ) + self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype) + + else: # Process prefill + # Use TK-implemented linear + terrace window attention + b, h, l, d = q.shape + device = q.device + # tk.hedgehog arguments + # y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device) + # kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device) + # k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device) + betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32)) + alphas = ( + 1 - betas + if self.affine_attention_factors + else torch.ones(betas.shape, dtype=torch.float32, device=device) + ) + q_map = self.feature_map_q.mlp.layer + k_map = self.feature_map_k.mlp.layer + # Saves outputs to y_pred, k_state, kv_state, where we fuse: + # 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) + # 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d + # 3. 
kv_state = torch.einsum(‘bhlf,bhld->bhfd’, + # f_k[:, :, :-self.window_size], + # v[:, :, :-self.window_size]) # b, h, f, d + # 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d + + tk_window_hedgehog_attention( + q.contiguous(), + k.contiguous(), + v.contiguous(), + self.y_true, + self.k_state, + self.kv_state, + q_map, + k_map, + alphas, + betas, + ) + + past_key_value.update_with_kv( + self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx + ) + + # Concatenate heads and apply output projection + y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size) + y_true = self.o_proj(y_true) + return y_true, None, past_key_value + + +class LinearAttentionTKWindowGenerationCache(LinearAttentionState): + """ + Class for `past_key_values` + -> Alternative to KV cache; here we only maintain a “KV state” and “K state” + -> Modified from transformers.cache_utils.DynamicCache (v4.36) + """ + + def __init__(self, window_size: int = 64) -> None: + super().__init__() + self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36 + self._seen_tokens_by_layer: List[int] = [] + self.window_size = window_size + + self.decode_kv_states: List[torch.Tensor] = [] + self.decode_k_states: List[torch.Tensor] = [] + self.k_cache: List[torch.Tensor] = [] + self.v_cache: List[torch.Tensor] = [] + + def update_with_kv( + self, + kv_state: torch.Tensor, + k_state: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_idx: int, + ): + """ + Update the cache with new KV and K states + """ + if layer_idx == 0: + self._seen_tokens += k.shape[2] + self._seen_tokens_by_layer.append(k.shape[2]) + + # Initialize KV and K states + if len(self.decode_k_states) <= layer_idx: + self.decode_kv_states.append(kv_state) + self.decode_k_states.append(k_state) + else: # Update KV and K states + self.decode_kv_states[layer_idx] = ( + self.decode_kv_states[layer_idx] + kv_state + ) + self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state + + self.k_cache.append(k[:, :, -self.window_size :, :]) + self.v_cache.append(v[:, :, -self.window_size :, :]) + + def update_for_decoding( + self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable + ): + """ + Update the cache for decoding + """ + k_cache = self.k_cache[layer_idx] + v_cache = self.v_cache[layer_idx] + k_state = feature_map_k(k_cache[:, :, :1, :]) + v_state = v_cache[:, :, :1, :] + kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to( + k.dtype + ) + + self.decode_kv_states[layer_idx] += kv_state + self.decode_k_states[layer_idx] += k_state + + self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2) + self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2) + if layer_idx == 0: + self._seen_tokens += k.shape[-2] + self._seen_tokens_by_layer[layer_idx] += k.shape[-2] + return ( + self.k_cache[layer_idx], + self.v_cache[layer_idx], + self.decode_kv_states[layer_idx], + self.decode_k_states[layer_idx], + ) diff --git a/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_long.py b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_long.py new file mode 100644 index 000000000..25df64940 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/linear_window_attention_tk_long.py @@ -0,0 +1,305 @@ +""" +LoLCATs attention combining sliding window and linear attentions +- Using the TK "terracing" arrangement +- Training over long sequences with fixed memory with recurrent 
view +- During attention transfer, use Flash Attention to compute softmax attention outputs + +For each layer: +- We first compute (softmax) attention over sliding windows +- We then compute standard linear attention to "fill in" the earlier parts +- We combine to model the entire sequence +""" + +import logging +from typing import Optional + +import torch +import torch.nn.functional as F +from transformers.cache_utils import Cache + +try: + from transformers.modeling_flash_attention_utils import _flash_attention_forward +except ModuleNotFoundError: + _flash_attention_forward = None # Transformers v4.36 + +from ..model.rotary import apply_rotary_pos_emb +from .linear_attention import softmax_attention +from .linear_window_attention_tk import LolcatsTKWindowAttention + +LOG = logging.getLogger( + "axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long" +) + + +class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention): + """ + Lolcats attention combining sliding window and linear attention + """ + + def __init__(self, remove_base_attn=True, **kwargs): + # keep self.base_attn for Flash Attention inference + super().__init__(remove_base_attn=True, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + """ + Forward pass with the option to compute attention weights multiple ways + if self.train_attention is True + -> Consistent with HuggingFace Transformers for easy use with their pretrained models + """ + b, l, _ = hidden_states.size() + if self.train_attention and self.base_inference: + with torch.no_grad(): + # LOG.debug(hidden_states.shape) + _y_true = flash_attention_2( + self, # self.base_attn, + hidden_states=hidden_states, + attention_mask=None, + position_ids=position_ids, + past_key_value=None, + output_attentions=False, + # output_hidden_states=False, + use_cache=False, + )[0] + # _y_true.shape is (batch_size, seq_len, num_heads, head_dim) + y_true = _y_true.reshape(b, l, -1).contiguous() + y_true = self.o_proj(y_true) + layer_io = (hidden_states, _y_true) # hack + # layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack + return y_true, layer_io, None + + q, k, v, kv_seq_len = self.process_qkv( + hidden_states, attention_mask, position_ids, past_key_value + ) + f_q, f_k = self.feature_map_q(q), self.feature_map_k(k) + + # attention_mask = None # For now this is always True + if past_key_value is None: # Regular training + window_factors = F.sigmoid(self.window_factors) + linear_factors = 1 - window_factors if self.affine_attention_factors else 1 + y_pred, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + ) + else: + past_key_value.window_size = self.decode_window_size + if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating + assert use_cache is True + _kv = past_key_value.update_for_decoding( + k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype + ) + k_cache, v_cache, f_kv_state, f_k_state = _kv + + # Sliding window + linear attention decode + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + + a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * ( + k.shape[-1] ** -0.5 + ) + # a_sm = torch.softmax(a_sm, dim=-1) + 
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True) + a_sm = window_factors * torch.exp(a_sm - a_sm_max) + sum_sm = a_sm.sum(dim=-1, keepdim=True) + + y_pred = torch.einsum( + "bhmn,bhnd->bhmd", a_sm, v_cache.float() + ) + linear_factors * torch.einsum( + "bhlf,bhfd->bhld", f_q.float(), f_kv_state.float() + ) + sum_ln = ( + linear_factors + * torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[ + ..., None + ] + ) + y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype) + + else: # Stateful training + if ( + self.state_grad_enabled + and self.layer_idx == 0 + and position_ids is not None + ): + LOG.debug( + f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]" + ) + LOG.debug( + f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}" + ) + try: + kv_state = past_key_value.kv_states[self.layer_idx] + k_state = past_key_value.k_states[self.layer_idx] + except IndexError: + kv_state, k_state = None, None + window_factors = F.sigmoid(self.window_factors) + linear_factors = ( + 1 - window_factors if self.affine_attention_factors else 1 + ) + y_pred, a_pred = self.quadratic_attention( + q, + k, + f_q, + f_k, + v, + window_factors, + linear_factors, + window_size=self.window_size, + kv_state=kv_state, + k_state=k_state, + ) + # Save and update KV cache and states + # past_key_value.update(k, v.detach(), self.layer_idx, + # fmap_key_states=f_k.detach(), + # accumulate_in_fp32=True) + past_key_value.update( + k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True + ) + + # Concatenate heads and apply output projection + _y_pred = y_pred.transpose(1, 2).contiguous() + y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size)) + + if self.train_attention: + with torch.no_grad(): + a_true = softmax_attention(q, k, None, causal=True)[1] + attn_weights = (_y_pred, (a_pred, a_true)) + else: + attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d) + return y_pred, attn_weights, past_key_value + + +# ----------------- +# Flash Attention 2 +# ----------------- + + +def flash_attention_2( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, +): + """ + Wrapper for LlamaFlashAttention2 + Copied and modified from HF Transformers v4.36 and v4.43 implementations + - (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402 + - (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456 + """ + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + try: # As in Transformers v4.36 + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len 
+= past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) + except Exception: # As in Transformers v4.39 + cos, sin = self.rotary_emb(key_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + LOG.debug( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self, "_flash_attention_forward", False): + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=True, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=0, # dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=True, + ) + return attn_output, past_key_value diff --git a/src/axolotl/integrations/lolcats/linear_attention/utils.py b/src/axolotl/integrations/lolcats/linear_attention/utils.py new file mode 100644 index 000000000..4e0314ce0 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linear_attention/utils.py @@ -0,0 +1,34 @@ +""" +Shared attention helpers +""" + +import torch + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
+ The hidden states go from: + (batch, num_key_value_heads, seqlen, head_dim) to + (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def mask_attention( + qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000 +) -> torch.Tensor: + """ + Apply attention mask (e.g., for padding) + """ + if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l) + return qk_dot.masked_fill(~attn_mask.bool(), mask_value) + else: + return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value) diff --git a/src/axolotl/integrations/lolcats/linearize_attention.py b/src/axolotl/integrations/lolcats/linearize_attention.py new file mode 100644 index 000000000..c3a065ff2 --- /dev/null +++ b/src/axolotl/integrations/lolcats/linearize_attention.py @@ -0,0 +1,222 @@ +""" +Convert attention to linear attention + +Adapted from: https://github.com/HazyResearch/lolcats/blob/main/src/model/convert_model.py + +@misc{zhang2024lolcatslowranklinearizinglarge, + title={LoLCATs: On Low-Rank Linearizing of Large Language Models}, + author={Michael Zhang and Simran Arora and Rahul Chalamala and Alan Wu and Benjamin Spector and Aaryan Singhal and Krithik Ramesh and Christopher Ré}, + year={2024}, + eprint={2410.10254}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2410.10254}, +} +""" + +import logging +from functools import partial +from typing import Any + +import torch.nn as nn +from tqdm import tqdm + +from axolotl.utils.dict import DictDefault + +LOG = logging.getLogger("axolotl.integrations.lolcats.linearize_attention") + + +def convert_attention( + model: nn.Module, + attention_config: DictDefault, + train_attention: bool = False, + remove_base_attn: bool = True, +): + """ + Call to convert all attention layers + """ + softmax_attns = [] + if "softmax_attentions" in attention_config: + softmax_attns = attention_config["softmax_attentions"] + if attention_config.attention_type != "softmax": + layers = traverse_layers(model) + for layer_idx, layer in enumerate( + tqdm(layers, desc="Converting attentions...") + ): + if layer_idx not in softmax_attns: + layer.self_attn = convert_llama_attention( + layer, + attention_config, + layers, + train_attention, + remove_base_attn, + ) + layer.self_attn.converted = True + else: # Freeze any preserved softmax attention layers + for p in layer.parameters(): + p.requires_grad = False + else: + LOG.info( + f"-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions" + ) + return model + + +def toggle_attention(llama_model: nn.Module, train: bool = False): + """ + Make attentions trainable if train is True + -> Set train_attention = False when finetuning + """ + for layer in traverse_layers(llama_model): + layer.self_attn.train_attention = train + return llama_model + + +def remove_base_attention(llama_model: nn.Module): + """ + Remove teacher attention after distillation (if we keep it) + """ + for layer in traverse_layers(llama_model): + if getattr(layer.self_attn, "base_attn", False): + del layer.self_attn.base_attn + return llama_model + + +def traverse_layers(model: nn.Module, verbose: bool = False): + """ + Return list of model layers + """ + try: + layers = 
model.model.layers + if verbose: + LOG.info("-> Loading from model.model.layers") + except AttributeError as e: # if base model + if verbose: + LOG.info(e) + try: + layers = model.layers + if verbose: + LOG.info("-> Loading from model.layers") + except AttributeError as e1: # If we make a PEFT model + if verbose: + LOG.info(e1) + layers = model.base_model.model.model.layers + if verbose: + LOG.info("-> Loading from model.base_model.model.model.layers") + return layers + + +def convert_llama_attention( + layer: nn.Module, + attention_config: DictDefault, + layers: list[nn.Module], # list of layers + train_attention: bool = False, + remove_base_attn: bool = True, +): + """ + Converts a single layer's attention layer as specified by attention_config + """ + return get_attention(**attention_config)( + base_attn=layer.self_attn, + layer_idx=layer.self_attn.layer_idx, # Transformers v4.36 + max_layer_idx=len(layers) - 1, + train_attention=train_attention, + remove_base_attn=remove_base_attn, + ) + + +def get_attention(attention_type: str, **kwargs): + """ + Get the linear attention class; either purely linear or linear with sliding window + -> 'linear' == 'lolcats_llama' + -> 'linear and sliding_window' == 'lolcats_llama_window_*' + """ + kwargs["attention_type"] = attention_type + + if attention_type == "lolcats_llama": + from .linear_attention import LolcatsLinearAttention + + return partial(LolcatsLinearAttention, **kwargs) + + elif attention_type == "lolcats_llama_window_tk": + from .linear_attention import LolcatsTKWindowAttention + + return partial(LolcatsTKWindowAttention, **kwargs) + + elif attention_type == "lolcats_llama_window_sw": + from .linear_attention import LolcatsSlidingWindowAttention + + return partial(LolcatsSlidingWindowAttention, **kwargs) + + elif attention_type == "lolcats_llama_window_sw_linear": + from .linear_attention.linear_window_attention_sw_linear import ( + LolcatsLinearSlidingWindowAttention, + ) + + return partial(LolcatsLinearSlidingWindowAttention, **kwargs) + + # Experimental chunked linear attentions below + elif attention_type == "lolcats_long_llama_window_tk": + from .linear_attention import LolcatsTKWindowLongAttention + + return partial(LolcatsTKWindowLongAttention, **kwargs) + + elif attention_type == "lolcats_long_llama_window_sw": + from .linear_attention import LolcatsSlidingWindowLongAttention + + return partial(LolcatsSlidingWindowLongAttention, **kwargs) + + # TK generation build (requires Thunderkittens) + elif attention_type == "lolcats_llama_window_tk_gen": + from .linear_attention import LolcatsWindowAttentionTKGen + + return partial(LolcatsWindowAttentionTKGen, **kwargs) + + else: + LOG.info(f"-> attention_type {attention_type} not handled... 
returning None") + return None + + +def get_attention_cache(attention_type: str, past_key_values: Any = None): + """ + Determine how we store past keys and values when generating + """ + if attention_type is None: + return past_key_values + + # LOG.info(f'Returning attention cache based on attention_type == {attention_type}') + elif "lolcats_llama_window_tk_gen" in attention_type: + from .linear_attention import LinearAttentionTKWindowGenerationCache + + return LinearAttentionTKWindowGenerationCache() + + elif "llama_window_tk" in attention_type: + from .linear_attention import LinearAttentionTKWindowCache + + return LinearAttentionTKWindowCache() + + elif "llama_window_sw" in attention_type: + from .linear_attention import LinearAttentionSlidingWindowCache + + return LinearAttentionSlidingWindowCache() + + elif "llama_window_sw_linear" in attention_type: + from .linear_attention import LinearAttentionSlidingWindowCache + + return LinearAttentionSlidingWindowCache() + + # TK generation build (requires Thunderkittens) + elif attention_type == "lolcats_llama_window_tk_gen": + from .linear_attention.linear_window_attention_tk_gen import ( + LinearAttentionTKWindowGenerationCache, + ) + + return LinearAttentionTKWindowGenerationCache() + + elif "softmax" in attention_type: + return past_key_values + + else: + from .linear_attention import LinearAttentionState + + return LinearAttentionState() diff --git a/src/axolotl/integrations/lolcats/model/feature_map.py b/src/axolotl/integrations/lolcats/model/feature_map.py new file mode 100644 index 000000000..65081c946 --- /dev/null +++ b/src/axolotl/integrations/lolcats/model/feature_map.py @@ -0,0 +1,336 @@ +""" +Learnable linear attention feature map classes and functions +""" + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def init_feature_map(name: str, mlp: nn.Module, **kwargs): + """ + Initialize feature map final activation for linear attention + """ + return FeatureMap(activation_name=name, mlp=mlp, **kwargs) + + +def init_feature_map_act(name: str, fullspace: bool = True, **kwargs): + """ + Initialize feature map final activation for linear attention + """ + if name == "softmax_dim" and fullspace: + return SoftmaxDim(**kwargs) + elif name == "softmax_dim" and not fullspace: + return SoftmaxDimHalfspace(**kwargs) + elif name == "exp_dim" and fullspace: + return Exp(**kwargs) + elif name == "exp_dim" and not fullspace: + return ExpHalfspace(**kwargs) + elif name == "pos_elu": + return PosELU(**kwargs) + elif name == "relu": + return ReLU(**kwargs) + + else: + raise NotImplementedError + + +def init_learned_kernel(name: str, **kwargs): + """ + Initialize feature map MLP for linear attention + """ + if name == "untied_head_einsum": + return FeatureMapMLP(**kwargs) + elif name == "untied_head_adapter": + return FeatureMapAdapter(**kwargs) + else: + raise NotImplementedError + + +class FeatureMap(nn.Module): + """ + Final 'activation' of feature map. 
Can probably be combined with + `FeatureMapMLP` below + + Full feature map is like f(xW + b) + -> This is the `f` part + """ + + def __init__( + self, + activation_name: str, + head_dim_idx: int = -1, + eps: float = 1e-12, + mlp: Optional[nn.Module] = None, + fullspace: bool = True, + ): + super().__init__() + self.head_dim_idx = head_dim_idx + self.eps = eps + self.mlp = mlp if mlp is not None else nn.Identity() + self.activation = init_feature_map_act(activation_name, fullspace, eps=eps) + + def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs): + """ + Assume x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x) + + def q_map(self, *args, **kwargs): + """ + Use for inference in case q and k feature maps differ + """ + return self.forward(*args, **kwargs) + + def k_map(self, *args, **kwargs): + """ + Use for inference in case q and k feature maps differ + """ + return self.forward(*args, **kwargs) + + +# ----------------------- +# Feature map activations +# ----------------------- +class FeatureMapAct(nn.Module): + """ + Base class for feature map activations + """ + + def __init__(self, eps: float = 1e-12): + super().__init__() + self.eps = eps + + def forward(self, x: torch.Tensor, *args, **kwargs): + """ + x.shape is (batch_size, n_heads, seq_len, head_dim) + """ + return x + + +class PosELU(FeatureMapAct): + """ + 1 + ELU activation as in https://arxiv.org/abs/2006.16236 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + return (1 + F.elu(x)).clamp(min=self.eps) + + +class ReLU(FeatureMapAct): + """ + ReLU activation as in https://arxiv.org/abs/2103.13076 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + return F.relu(x).clamp(min=self.eps) + + +class SoftmaxDim(FeatureMapAct): + """ + Softmax activation as in https://arxiv.org/abs/2402.04347 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + return torch.cat( + [torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1 + ).clamp(min=self.eps) + + +class SoftmaxDimHalfspace(FeatureMapAct): + """ + Softmax activation as in https://arxiv.org/abs/2402.04347 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + return torch.softmax(x, dim=-1).clamp(min=self.eps) + + +class Exp(FeatureMapAct): + """ + Exp activation as in https://arxiv.org/abs/2402.04347 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + x_max = torch.amax(x, dim=-1, keepdim=True) + x_min = torch.amin(x, dim=-1, keepdim=True) + return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp( + min=self.eps + ) + + +class ExpHalfspace(FeatureMapAct): + """ + Exp activation as in https://arxiv.org/abs/2402.04347 + """ + + def forward(self, x: torch.Tensor, *args, **kwargs): + x_max = torch.amax(x, dim=-1, keepdim=True) + return torch.exp(x - x_max).clamp(min=self.eps) + + +# ---------------- +# Feature map MLPs +# ---------------- + + +class FeatureMapMLP(nn.Module): + """ + Learnable MLP in feature map. 
+ + Full feature map is like f(xW + b) + -> This is the `W` and (optional) `b` part + """ + + def __init__( + self, + num_heads: int, + head_dim: int, # input dim + feature_dim: int, # output dim + dtype: torch.dtype, + device: torch.device, + skip_connection: bool = False, + bias: bool = False, + zero_init: bool = False, + normal_init: bool = False, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + self.feature_dim = feature_dim + self.dtype = dtype + self.device = device + self.skip_connection = skip_connection + self.bias = bias + self.zero_init = zero_init + self.normal_init = normal_init + self.init_weights_() + + if self.zero_init: # Zero-out weights or set as identity post-initialization + self.zero_init_with_skip_() if self.skip_connection else self.zero_init_() + + if self.normal_init: + with torch.no_grad(): + nn.init.normal_(self.layer) + + if self.skip_connection: + assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}" + assert self.head_dim == self.feature_dim, assertion_fail + + def init_weights_(self): + """ + Initialize (W)eights and (b)iases + """ + self.layer = nn.Parameter( + torch.zeros( + (self.num_heads, self.head_dim, self.feature_dim), + dtype=self.dtype, + device=self.device, + ) + ) + nn.init.kaiming_uniform_(self.layer) + + if self.bias: + self.bias = nn.Parameter( + torch.zeros( + (1, self.num_heads, 1, 1), # self.feature_dim), + dtype=self.dtype, + device=self.device, + ) + ) + nn.init.kaiming_uniform_(self.bias) + else: + self.bias = 0.0 # hack + + def zero_init_with_skip_(self): + """ + Initialize weights to zero matrix if skip connection + """ + with torch.no_grad(): + nn.init.zeros_(self.layer) + + def zero_init_(self): + """ + Initialize weights to identity matrix if no skip connection + """ + with torch.no_grad(): + for i in range(self.layer.shape[0]): + try: + nn.init.eye_(self.layer[i]) + except RuntimeError: + with torch.no_grad(): + dtype = self.layer[i].dtype + weight = torch.eye( + *self.layer[i].shape, + requires_grad=self.layer[i].requires_grad, + device=self.layer[i].device, + ) + self.layer[i] = weight.to(dtype=dtype) + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, num_heads, seq_len, head_dim) + """ + _x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias + return x + _x if self.skip_connection else _x + + +class FeatureMapAdapter(FeatureMapMLP): + """ + Learnable Feature map with bottleneck adapter + as in https://arxiv.org/abs/1902.00751 + + We don't use but could be fun to try + """ + + def __init__(self, hidden_dim: int, *args, **kwargs): + kwargs["skip_connection"] = True + kwargs["bias"] = True + kwargs["zero_init"] = True + self.hidden_dim = hidden_dim + super().__init__(*args, **kwargs) + + def init_weights_(self): + """ + Initialize (W)eights and (b)iases + """ + kwargs = {"dtype": self.dtype, "device": self.device} + self.layer0 = nn.Parameter( + torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs) + ) + self.layer1 = nn.Parameter( + torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs) + ) + nn.init.kaiming_uniform_(self.layer0) + nn.init.kaiming_uniform_(self.layer1) + + self.bias0 = nn.Parameter( + torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs) + ) + self.bias1 = nn.Parameter( + torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs) + ) + nn.init.kaiming_uniform_(self.bias0) + 
nn.init.kaiming_uniform_(self.bias1) + + def zero_init_with_skip_(self): + with torch.no_grad(): + nn.init.zeros_(self.layer0) + nn.init.zeros_(self.layer1) + nn.init.zeros_(self.bias0) + nn.init.zeros_(self.bias1) + + def zero_init_(self): + raise NotImplementedError + + def forward(self, x: torch.Tensor): + """ + Assume x.shape is (batch_size, num_heads, seq_len, head_dim) + -> Down-project, apply nonlinearity, up-project; add skip connection + """ + _x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0 + _x = F.relu(_x) + _x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1 + return x + _x if self.skip_connection else _x diff --git a/src/axolotl/integrations/lolcats/model/rotary.py b/src/axolotl/integrations/lolcats/model/rotary.py new file mode 100644 index 000000000..ed885dcbc --- /dev/null +++ b/src/axolotl/integrations/lolcats/model/rotary.py @@ -0,0 +1,204 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Rotary embeddings. Same as usual for Transformer models. + +Note these are modified from HF Transformers v4.36, from: +- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py +- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123 +""" +from typing import Optional + +import torch +import torch.nn as nn + + +def get_rotary_embeddings( + rope_scaling_type: Optional[str] = None, + head_dim: int = 128, + max_position_embeddings: int = 4096, + rope_theta: float = 10000.0, + rope_scaling_factor: float = 1.0, + device: Optional[torch.device] = None, +) -> nn.Module: + """Return rotary embedding object""" + if rope_scaling_type is None: + return RotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + base=rope_theta, + device=device, + ) + elif rope_scaling_type == "linear": + return LinearScalingRotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=rope_scaling_factor, + base=rope_theta, + device=device, + ) + elif rope_scaling_type == "dynamic": + return DynamicNTKScalingRotaryEmbedding( + head_dim, + max_position_embeddings=max_position_embeddings, + scaling_factor=rope_scaling_factor, + base=rope_theta, + device=device, + ) + else: + raise NotImplementedError( + f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.' 
+ ) + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors.""" + if position_ids is not None: + cos, sin = cos[position_ids], sin[position_ids] + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36) +class RotaryEmbedding(nn.Module): + """Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + """ + Compute rotary embeddings + """ + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers/models/llama/modeling_llama.py at v4.36 +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers/models/llama/modeling_llama.py at v4.36 +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
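For orientation, here is a minimal usage sketch of the learned feature map added above. The import path is assumed from the file location in this diff, and the shapes are illustrative; with the fullspace "softmax_dim" activation the output feature dimension doubles, because softmax(+x) and softmax(-x) are concatenated.

import torch

from axolotl.integrations.lolcats.model.feature_map import (  # path assumed from this diff
    init_feature_map,
    init_learned_kernel,
)

batch, heads, seq_len, head_dim, feature_dim = 2, 4, 16, 64, 64  # hypothetical shapes

# W (and optional b): one learned projection per head, shape (heads, head_dim, feature_dim)
mlp = init_learned_kernel(
    "untied_head_einsum",
    num_heads=heads,
    head_dim=head_dim,
    feature_dim=feature_dim,
    dtype=torch.float32,
    device=torch.device("cpu"),
)
# f: final activation applied to xW + b
feature_map = init_feature_map("softmax_dim", mlp=mlp, fullspace=True)

x = torch.randn(batch, heads, seq_len, head_dim)
phi = feature_map(x)  # also exposed as q_map / k_map for inference
print(phi.shape)      # torch.Size([2, 4, 16, 128]) -> 2 * feature_dim (softmax over +x and -x)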
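Similarly, a small sketch of the helpers in rotary.py (import path again assumed from the file location; shapes are illustrative). Passing rope_scaling_type=None returns the plain RotaryEmbedding, while "linear" and "dynamic" select the scaled variants.

import torch

from axolotl.integrations.lolcats.model.rotary import (  # path assumed from this diff
    apply_rotary_pos_emb,
    get_rotary_embeddings,
)

batch, heads, seq_len, head_dim = 2, 4, 16, 128  # hypothetical shapes
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

rotary = get_rotary_embeddings(
    rope_scaling_type=None,           # or "linear" / "dynamic" with rope_scaling_factor
    head_dim=head_dim,
    max_position_embeddings=4096,
    rope_theta=10000.0,
)
cos, sin = rotary(q, seq_len=seq_len)  # cached cos/sin tables, each of shape (seq_len, head_dim)

position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, seq_len)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)  # same shapes as q, k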
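Finally, a sketch of how get_attention_cache dispatches on the attention_type string via substring checks. The function is defined earlier in this patch, outside the hunks shown here, so the import below is a placeholder; the returned cache classes come from the integration's linear_attention module referenced in the function body.

# Placeholder import: adjust to wherever get_attention_cache actually lives in this integration.
from axolotl.integrations.lolcats.model import get_attention_cache  # assumed location

get_attention_cache(None, past_key_values=None)    # attention_type is None: reuse the cache as-is
get_attention_cache("softmax", past_key_values=None)  # "softmax" branch: keep past_key_values
get_attention_cache("lolcats_llama_window_tk")     # -> LinearAttentionTKWindowCache()
get_attention_cache("lolcats_llama_window_sw")     # -> LinearAttentionSlidingWindowCache()
get_attention_cache("lolcats_llama")               # fallback -> LinearAttentionState()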