feat: add lolcats with fixed typed
This commit is contained in:
201
src/axolotl/integrations/lolcats/LICENSE
Normal file
201
src/axolotl/integrations/lolcats/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright [yyyy] [name of copyright owner]
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
"""
|
||||||
|
Linear and linear attention + sliding window classes
|
||||||
|
"""
|
||||||
|
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
|
||||||
|
from .linear_window_attention_sw import (
|
||||||
|
LinearAttentionSlidingWindowCache,
|
||||||
|
LolcatsSlidingWindowAttention,
|
||||||
|
)
|
||||||
|
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
||||||
|
from .linear_window_attention_tk import (
|
||||||
|
LinearAttentionTKWindowCache,
|
||||||
|
LolcatsTKWindowAttention,
|
||||||
|
)
|
||||||
|
from .linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
LolcatsWindowAttentionTKGen,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Experimental chunk linear attentions
|
||||||
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
@@ -0,0 +1,561 @@
|
|||||||
|
"""
|
||||||
|
Linear attention classes
|
||||||
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
|
from typing import Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
|
try:
|
||||||
|
from csrc import causal_dot_product as fast_causal_dot_product
|
||||||
|
except ImportError:
|
||||||
|
fast_causal_dot_product = None
|
||||||
|
|
||||||
|
from ..model.feature_map import init_feature_map, init_learned_kernel
|
||||||
|
from ..model.rotary import apply_rotary_pos_emb, get_rotary_embeddings
|
||||||
|
from .utils import repeat_kv
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# Attention functions
|
||||||
|
# -------------------
|
||||||
|
|
||||||
|
|
||||||
|
def causal_dot_product(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Causal linear attention dot product
|
||||||
|
- If available, use CUDA kernel from fast-transformers
|
||||||
|
"""
|
||||||
|
if fast_causal_dot_product is None:
|
||||||
|
kv = torch.einsum("bhlf,bhld->bhlfd", k, v)
|
||||||
|
return torch.einsum("bhlf,bhlfd->bhld", q, kv.cumsum(dim=2))
|
||||||
|
return fast_causal_dot_product(q, k, v)
|
||||||
|
|
||||||
|
|
||||||
|
def linear_attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
fp32_attention: bool = False,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
"""
|
||||||
|
Compute linear attention with CUDA kernel implementation from fast-transformers
|
||||||
|
- https://github.com/idiap/fast-transformers
|
||||||
|
- Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim);
|
||||||
|
v is shape (b, h, l, head_dim)
|
||||||
|
"""
|
||||||
|
dtype = q.dtype
|
||||||
|
# Causal mask already applied
|
||||||
|
y = causal_dot_product(
|
||||||
|
q.contiguous().to(dtype=torch.float32),
|
||||||
|
k.contiguous().to(dtype=torch.float32),
|
||||||
|
v.contiguous().to(dtype=torch.float32),
|
||||||
|
)
|
||||||
|
if fp32_attention:
|
||||||
|
y = (
|
||||||
|
y
|
||||||
|
/ (
|
||||||
|
torch.einsum("bhld,bhld->bhl", q.float(), k.float().cumsum(dim=2)) + eps
|
||||||
|
)[..., None]
|
||||||
|
).to(dtype=dtype)
|
||||||
|
else:
|
||||||
|
y = y.to(dtype=dtype)
|
||||||
|
k = k.float().cumsum(dim=2).to(dtype=dtype)
|
||||||
|
y = y / (torch.einsum("bhld,bhld->bhl", q, k) + eps)[..., None]
|
||||||
|
return y, None, None
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: Optional[torch.Tensor] = None,
|
||||||
|
causal: bool = True,
|
||||||
|
fp32_attention: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Standard softmax attention; only compute outputs if v is not None
|
||||||
|
-> Assume q, k, v are shape (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
y = None
|
||||||
|
a = torch.einsum("bhmd,bhnd->bhmn", q, k) * (k.shape[-1] ** -0.5)
|
||||||
|
if causal: # Apply causal mask
|
||||||
|
m, n = a.shape[-2:]
|
||||||
|
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||||
|
n - m + 1
|
||||||
|
)
|
||||||
|
a = a.masked_fill(causal_mask, -torch.finfo(a.dtype).max)
|
||||||
|
if fp32_attention:
|
||||||
|
a = torch.softmax(a, dim=-1, dtype=torch.float32).to(q.dtype)
|
||||||
|
else:
|
||||||
|
a = torch.softmax(a, dim=-1)
|
||||||
|
if v is not None:
|
||||||
|
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||||
|
return y, a, None
|
||||||
|
|
||||||
|
|
||||||
|
def quadratic_attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: Optional[torch.Tensor] = None,
|
||||||
|
causal: bool = True,
|
||||||
|
fp32_attention: bool = False,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute attention with feature maps by instantiating L x L matrix of attention weights
|
||||||
|
-> Use for attention distillation
|
||||||
|
-> Assume q, k are shape (batch_size, num_heads, seq_len, feature_dim); v is shape (b, h, l, head_dim)
|
||||||
|
"""
|
||||||
|
y = None
|
||||||
|
dtype = q.dtype
|
||||||
|
if fp32_attention:
|
||||||
|
q, k = q.float(), k.float()
|
||||||
|
a = torch.einsum("bhmd,bhnd->bhmn", q, k) # note we don't scale, tho we could
|
||||||
|
if causal: # Apply causal mask
|
||||||
|
m, n = a.shape[-2:]
|
||||||
|
causal_mask = torch.ones((m, n), device=a.device, dtype=torch.bool).triu(
|
||||||
|
n - m + 1
|
||||||
|
)
|
||||||
|
a = a.masked_fill(causal_mask, 0)
|
||||||
|
# Normalize to compute attention
|
||||||
|
a = a / (a.sum(dim=-1, keepdim=True) + eps)
|
||||||
|
a = a.to(dtype=dtype) if fp32_attention else a
|
||||||
|
if torch.isnan(a).sum() > 0:
|
||||||
|
breakpoint()
|
||||||
|
if v is not None:
|
||||||
|
y = torch.einsum("bhmn,bhnd->bhmd", a, v)
|
||||||
|
return y, a, None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsLinearAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
LoLCATs attention implementation initialized from a
|
||||||
|
`LlamaAttention` or `MistralAttention` object (base_attn)
|
||||||
|
|
||||||
|
Most of the arguments are directly tied to argparse args
|
||||||
|
- For now we don't support padding.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_attn: nn.Module, # like LlamaAttention
|
||||||
|
feature_map: str,
|
||||||
|
feature_map_kwargs: dict,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
max_layer_idx: Optional[int] = None,
|
||||||
|
learned_kernel: Optional[str] = None,
|
||||||
|
learned_kernel_kwargs: Optional[dict] = None,
|
||||||
|
tie_qk_kernels: Optional[bool] = False,
|
||||||
|
rotary_config: Optional[dict] = None,
|
||||||
|
train_attention: Optional[bool] = False,
|
||||||
|
remove_base_attn: bool = True,
|
||||||
|
attention_type: Optional[str] = "lolcats_llama",
|
||||||
|
mask_value: int = 0,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
fp32_attention: bool = False,
|
||||||
|
track_state_grads: bool = False,
|
||||||
|
rank: Optional[int] = 0,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.base_config = getattr(base_attn, "config", None)
|
||||||
|
if self.base_config is not None:
|
||||||
|
self.base_config = self.base_config.to_dict()
|
||||||
|
self.attention_type = attention_type
|
||||||
|
self.mask_value = mask_value
|
||||||
|
self.eps = eps
|
||||||
|
self.layer_idx = layer_idx if layer_idx is not None else base_attn.layer_idx
|
||||||
|
self.max_layer_idx = max_layer_idx
|
||||||
|
self.tie_qk_kernels = tie_qk_kernels
|
||||||
|
self.train_attention = train_attention
|
||||||
|
self.base_inference = False
|
||||||
|
self.fp32_attention = fp32_attention
|
||||||
|
self.track_state_grads = track_state_grads
|
||||||
|
if rank == 0: # multi-gpu
|
||||||
|
if fp32_attention and layer_idx == 0:
|
||||||
|
print(f"-> fp32_attention is {fp32_attention}")
|
||||||
|
if layer_idx == 0 and feature_map_kwargs is not None:
|
||||||
|
for k, v in feature_map_kwargs.items():
|
||||||
|
print(f"-> {k}: {v}")
|
||||||
|
if layer_idx == 0 and learned_kernel_kwargs is not None:
|
||||||
|
for k, v in learned_kernel_kwargs.items():
|
||||||
|
print(f"-> {k}: {v}")
|
||||||
|
|
||||||
|
self.remove_base_attn = remove_base_attn
|
||||||
|
|
||||||
|
# Rotary embeddings (patch for Llama 3.1, Transformer v4.43.0)
|
||||||
|
self.rotary_config = rotary_config
|
||||||
|
# if isinstance(self.rotary_config, DictDefault):
|
||||||
|
# self.rotary_config = OmegaConf.to_container(self.rotary_config)
|
||||||
|
|
||||||
|
self.rotary_emb = None
|
||||||
|
if self.base_config is not None and self.rotary_config is None:
|
||||||
|
self.rotary_emb = base_attn.rotary_emb
|
||||||
|
|
||||||
|
self.init_weights_(base_attn, remove_base_attn)
|
||||||
|
self.init_feature_map_(
|
||||||
|
feature_map, feature_map_kwargs, learned_kernel, learned_kernel_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_feature_map_(
|
||||||
|
self,
|
||||||
|
feature_map: str,
|
||||||
|
feature_map_kwargs: dict,
|
||||||
|
learned_kernel: Optional[str] = None,
|
||||||
|
learned_kernel_kwargs: Optional[dict] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize MLP-based feature map
|
||||||
|
"""
|
||||||
|
self.fmap_gqa = False # Turn True if specified below
|
||||||
|
if learned_kernel is not None and learned_kernel_kwargs is not None:
|
||||||
|
# Ensure dict
|
||||||
|
learned_kernel_kwargs = {k: v for k, v in learned_kernel_kwargs.items()}
|
||||||
|
learned_kernel_kwargs["num_heads"] = self.num_heads
|
||||||
|
learned_kernel_kwargs["head_dim"] = self.head_dim
|
||||||
|
learned_kernel_kwargs["dtype"] = self.q_proj.weight.dtype
|
||||||
|
learned_kernel_kwargs["device"] = self.q_proj.weight.device
|
||||||
|
# Create MLP
|
||||||
|
mlp_learned_kernel = init_learned_kernel(
|
||||||
|
learned_kernel, **learned_kernel_kwargs
|
||||||
|
)
|
||||||
|
# Add "activation"; see src.models.feature_map.py
|
||||||
|
self.feature_map_q = init_feature_map(
|
||||||
|
name=feature_map, mlp=mlp_learned_kernel, **feature_map_kwargs
|
||||||
|
)
|
||||||
|
if self.tie_qk_kernels: # tie mlp weights for query and key feature maps
|
||||||
|
self.feature_map_k = self.feature_map_q
|
||||||
|
else:
|
||||||
|
self.feature_map_k = copy.deepcopy(self.feature_map_q)
|
||||||
|
|
||||||
|
def init_weights_(self, base_attn: nn.Module, remove_base_attn: bool = True):
|
||||||
|
"""
|
||||||
|
Initialize module layers, weights, positional dependencies, etc.
|
||||||
|
from original softmax attention layer (base_attn)
|
||||||
|
"""
|
||||||
|
# Make other attributes accessible
|
||||||
|
self.attention_dropout = 0 # We don't use dropout
|
||||||
|
self.hidden_size = base_attn.hidden_size
|
||||||
|
self.num_heads = base_attn.num_heads
|
||||||
|
self.head_dim = base_attn.head_dim
|
||||||
|
self.num_key_value_heads = base_attn.num_key_value_heads
|
||||||
|
self.num_key_value_groups = base_attn.num_key_value_groups
|
||||||
|
|
||||||
|
self.q_shape = [self.num_heads, self.head_dim]
|
||||||
|
self.k_shape = [self.num_key_value_heads, self.head_dim]
|
||||||
|
self.v_shape = [self.num_key_value_heads, self.head_dim]
|
||||||
|
device = base_attn.q_proj.weight.device
|
||||||
|
# Rotary embeddings
|
||||||
|
if self.rotary_emb is None:
|
||||||
|
self.max_position_embeddings = base_attn.max_position_embeddings
|
||||||
|
scaling_factor = getattr(base_attn.rotary_emb, "scaling_factor", 1.0)
|
||||||
|
if self.rotary_config is None:
|
||||||
|
self.rotary_emb = get_rotary_embeddings(
|
||||||
|
rope_scaling_type=None,
|
||||||
|
head_dim=self.head_dim,
|
||||||
|
max_position_embeddings=self.max_position_embeddings, # base_attn.rotary_emb.max_position_embeddings,
|
||||||
|
rope_theta=base_attn.rotary_emb.base,
|
||||||
|
rope_scaling_factor=scaling_factor, # base_attn.rotary_emb.scaling_factor,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if "device" not in self.rotary_config:
|
||||||
|
self.rotary_config["device"] = device
|
||||||
|
self.rotary_emb = get_rotary_embeddings(**self.rotary_config)
|
||||||
|
|
||||||
|
# Copy original model projection layers
|
||||||
|
self.q_proj = base_attn.q_proj
|
||||||
|
self.k_proj = base_attn.k_proj
|
||||||
|
self.v_proj = base_attn.v_proj
|
||||||
|
self.o_proj = base_attn.o_proj
|
||||||
|
try: # If wanting to use FA2 for ground-truth inference
|
||||||
|
self._flash_attn_uses_top_left_mask = (
|
||||||
|
base_attn._flash_attn_uses_top_left_mask
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self.remove_base_attn or remove_base_attn:
|
||||||
|
del base_attn # We don't need to keep these around
|
||||||
|
else:
|
||||||
|
self.base_attn = base_attn # For some training runs helpful to just call
|
||||||
|
|
||||||
|
def process_qkv(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Any] = None,
|
||||||
|
): # "legacy" cache approach
|
||||||
|
"""
|
||||||
|
Compute queries, keys, and values
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
q = self.q_proj(hidden_states)
|
||||||
|
k = self.k_proj(hidden_states)
|
||||||
|
v = self.v_proj(hidden_states)
|
||||||
|
kv_seq_len = k.shape[-2]
|
||||||
|
|
||||||
|
# Shape is (batch_size, seq_len, num_heads, head_dim)
|
||||||
|
q = q.view(b, l, *self.q_shape).transpose(1, 2)
|
||||||
|
k = k.view(b, l, *self.k_shape).transpose(1, 2)
|
||||||
|
v = v.view(b, l, *self.v_shape).transpose(1, 2)
|
||||||
|
|
||||||
|
if (
|
||||||
|
past_key_value is not None
|
||||||
|
): # and k.shape[2] > q.shape[2]: # e.g., when generating
|
||||||
|
past_key_value.window_size = getattr(
|
||||||
|
self, "decode_window_size", None
|
||||||
|
) # self.decode_window_size
|
||||||
|
if isinstance(
|
||||||
|
past_key_value, Cache
|
||||||
|
): # In Transformers v4.36+ this is a DynamicCache object
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(
|
||||||
|
kv_seq_len, self.layer_idx
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kv_seq_len += past_key_value[0].shape[-2]
|
||||||
|
|
||||||
|
# Apply rotary embeddings and repeat for GQA
|
||||||
|
if position_ids is not None and kv_seq_len <= position_ids[0, -1]:
|
||||||
|
kv_seq_len = position_ids[0, -1] + 1 # hack for adjusting position ids
|
||||||
|
|
||||||
|
if self.rotary_emb is None:
|
||||||
|
raise ValueError("Rotary embeddings not initialized")
|
||||||
|
|
||||||
|
try: # As in Transformers v4.36
|
||||||
|
cos, sin = self.rotary_emb(k, seq_len=kv_seq_len)
|
||||||
|
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
|
||||||
|
except TypeError: # As in Transformers v4.39+
|
||||||
|
cos, sin = self.rotary_emb(v, position_ids)
|
||||||
|
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||||
|
|
||||||
|
k = repeat_kv(k, self.num_key_value_groups)
|
||||||
|
v = repeat_kv(v, self.num_key_value_groups)
|
||||||
|
return q, k, v, kv_seq_len
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Any] = None, # "legacy" cache approach
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass modified from transformers.models.mistral.modeling_mistral (v4.36)
|
||||||
|
- Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.base_inference:
|
||||||
|
with torch.no_grad():
|
||||||
|
# 1. Compute "ground-truth" attention output and weights
|
||||||
|
y_true, _, _ = softmax_attention(q, k, v, causal=True)
|
||||||
|
y_true = (
|
||||||
|
y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
attn_weights = (None, None)
|
||||||
|
|
||||||
|
elif self.train_attention: # Distilling / learning attentions
|
||||||
|
# Note for now we assume no padding when distilling; attention masks only enforce causality
|
||||||
|
assert (
|
||||||
|
output_attentions is True
|
||||||
|
), f"When training feature maps, output_attentions should be True but is {output_attentions}"
|
||||||
|
with torch.no_grad():
|
||||||
|
# 1. Compute "ground-truth" attention output and weights
|
||||||
|
_y_true, attn_true, _ = softmax_attention(q, k, v, causal=True)
|
||||||
|
y_true = (
|
||||||
|
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
|
||||||
|
# 2. Compute "predicted" attention (just weights)
|
||||||
|
q, k = self.feature_map_q.q_map(q), self.feature_map_k.k_map(k)
|
||||||
|
y_pred, attn_pred, _ = quadratic_attention(q, k, v, causal=True)
|
||||||
|
attn_weights = ( # type: ignore
|
||||||
|
(attn_pred, attn_true),
|
||||||
|
(y_pred, _y_true),
|
||||||
|
) # Save both attention weights so we can supervise.
|
||||||
|
|
||||||
|
else: # Finetuning
|
||||||
|
q, k = self.feature_map_q(q), self.feature_map_k(k)
|
||||||
|
# Apply prefill mask
|
||||||
|
if attention_mask is not None and q.shape[2] > 1:
|
||||||
|
if len(attention_mask.shape) == 4:
|
||||||
|
lin_attn_mask = (attention_mask == 0)[:, :1, -1, :l][
|
||||||
|
..., None
|
||||||
|
] # b, 1, k_len, 1
|
||||||
|
else:
|
||||||
|
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
|
||||||
|
k = k.masked_fill(~lin_attn_mask, 0)
|
||||||
|
|
||||||
|
if past_key_value is not None: # Initialize states
|
||||||
|
if len(past_key_value.kv_states) == self.layer_idx:
|
||||||
|
b, h, _, f = k.shape
|
||||||
|
past_key_value.kv_states.append(
|
||||||
|
torch.zeros(
|
||||||
|
b, h, f, self.head_dim, dtype=q.dtype, device=q.device
|
||||||
|
)
|
||||||
|
)
|
||||||
|
past_key_value.k_states.append(
|
||||||
|
torch.zeros(b, h, 1, f, dtype=q.dtype, device=q.device)
|
||||||
|
)
|
||||||
|
# Generating
|
||||||
|
if q.shape[2] == 1 and kv_seq_len > 1 and past_key_value is not None:
|
||||||
|
assert use_cache is True
|
||||||
|
kv_state, k_state = past_key_value.update(
|
||||||
|
k, v, self.layer_idx, accumulate_in_fp32=self.fp32_attention
|
||||||
|
)
|
||||||
|
if self.fp32_attention:
|
||||||
|
q = q.float()
|
||||||
|
y_true = (
|
||||||
|
torch.einsum("bhlf,bhfd->bhld", q, kv_state.float())
|
||||||
|
/ torch.einsum("bhlf,bhlf->bhl", q, k_state.float())[
|
||||||
|
..., None
|
||||||
|
]
|
||||||
|
).to(dtype=k.dtype)
|
||||||
|
else:
|
||||||
|
y_true = (
|
||||||
|
torch.einsum("bhlf,bhfd->bhld", q, kv_state)
|
||||||
|
/ torch.einsum("bhlf,bhlf->bhl", q, k_state)[..., None]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||||
|
k_state = past_key_value.k_states[self.layer_idx]
|
||||||
|
y_true, _, _ = linear_attention(
|
||||||
|
q, k, v, self.fp32_attention, self.eps
|
||||||
|
) # Ordinarily the states are ignored
|
||||||
|
past_key_value.update(
|
||||||
|
k.detach(),
|
||||||
|
v.detach(),
|
||||||
|
self.layer_idx,
|
||||||
|
accumulate_in_fp32=self.fp32_attention,
|
||||||
|
)
|
||||||
|
# doing some unnecessary recomputation here
|
||||||
|
else:
|
||||||
|
y_true, _, _ = linear_attention(q, k, v, self.fp32_attention, self.eps)
|
||||||
|
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionState(Cache):
|
||||||
|
"""
|
||||||
|
Handle the KV and K states for linear attention
|
||||||
|
- Adopts HF Transformers `past_key_values` convention
|
||||||
|
- Inherits from `Cache` class
|
||||||
|
- Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||||
|
self._seen_tokens_by_layer: List[int] = []
|
||||||
|
self.kv_states: List[torch.Tensor] = []
|
||||||
|
self.k_states: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
|
"""
|
||||||
|
Returns the sequence length of the cached states. A layer index can be optionally passed.
|
||||||
|
"""
|
||||||
|
if layer_idx is None:
|
||||||
|
raise ValueError("Layer index must not be None")
|
||||||
|
|
||||||
|
if len(self._seen_tokens_by_layer) <= layer_idx: # Initializing kv and k states
|
||||||
|
self._seen_tokens_by_layer.append(0)
|
||||||
|
return self._seen_tokens_by_layer[layer_idx]
|
||||||
|
|
||||||
|
def get_max_length(self) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_usable_length(
|
||||||
|
self, new_seq_length: int, layer_idx: Optional[int] = 0
|
||||||
|
) -> int:
|
||||||
|
"""Given the sequence length of the new inputs, returns the usable length of the cache."""
|
||||||
|
# Cache without size limit -> all cache is usable
|
||||||
|
# Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
|
||||||
|
# length, we will need to evict part of the cache (and thus not all cache is usable)
|
||||||
|
max_length = self.get_max_length()
|
||||||
|
previous_seq_length = self.get_seq_length(layer_idx)
|
||||||
|
if max_length is not None and previous_seq_length + new_seq_length > max_length:
|
||||||
|
return max_length - new_seq_length
|
||||||
|
return previous_seq_length
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
cache_kwargs: Optional[Any] = None,
|
||||||
|
accumulate_in_fp32: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if layer_idx is None:
|
||||||
|
raise ValueError("Layer index must not be None")
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += key_states.shape[-2]
|
||||||
|
dtype = key_states.dtype
|
||||||
|
if accumulate_in_fp32:
|
||||||
|
key_states, value_states = key_states.float(), value_states.float()
|
||||||
|
|
||||||
|
kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd", key_states, value_states
|
||||||
|
).detach()
|
||||||
|
k_state = key_states.sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
).detach() # b, h, 1, f; note the 1
|
||||||
|
# Update the cache
|
||||||
|
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||||
|
print(
|
||||||
|
"if len(self.k_states) <= layer_idx: # Initializing kv and k states"
|
||||||
|
)
|
||||||
|
self.kv_states.append(kv_state.to(dtype))
|
||||||
|
self.k_states.append(k_state.to(dtype))
|
||||||
|
else:
|
||||||
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
self.kv_states[layer_idx] = kv_state
|
||||||
|
self.k_states[layer_idx] = k_state
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||||
|
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||||
|
|
||||||
|
def to_legacy_cache(self):
|
||||||
|
"""Hack, but just return self"""
|
||||||
|
return self
|
||||||
|
|
||||||
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||||
|
"""
|
||||||
|
Reorders the cache for beam search, given the selected beam indices.
|
||||||
|
-> Copied from transformers/src/transformers/cache_utils.py
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Reordering cache not implemented for LinearAttentionState"
|
||||||
|
)
|
||||||
@@ -0,0 +1,460 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using "standard" sliding windows
|
||||||
|
- Didactically computes outputs with n^2 attention weights for now
|
||||||
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
softmax_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
|
||||||
|
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Return masks for softmax and linear attention terms
|
||||||
|
-> 1 is include, 0 is ignore
|
||||||
|
"""
|
||||||
|
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||||
|
k_len - q_len
|
||||||
|
)
|
||||||
|
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||||
|
k_len - q_len - window_size
|
||||||
|
)
|
||||||
|
window_mask = causal_mask - linear_mask
|
||||||
|
# Return softmax mask (window), linear attention mask
|
||||||
|
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||||
|
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
f_q: torch.Tensor,
|
||||||
|
f_k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_factor: torch.Tensor,
|
||||||
|
linear_factor: torch.Tensor,
|
||||||
|
window_size: int,
|
||||||
|
kv_state: Optional[torch.Tensor] = None,
|
||||||
|
k_state: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mask_value: float = -1e8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hybrid attention combining sliding window and linear attentions
|
||||||
|
"""
|
||||||
|
|
||||||
|
mask_window, mask_linear = get_masks(
|
||||||
|
window_size, q.shape[-2], k.shape[-2], q.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Sliding window (softmax attention)
|
||||||
|
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||||
|
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 2. Under window (linear attention)
|
||||||
|
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||||
|
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||||
|
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 3. Combine
|
||||||
|
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||||
|
# Allow outputs to also depend on prior kv_state and k_state
|
||||||
|
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||||
|
if (
|
||||||
|
kv_state is not None and k_state is not None
|
||||||
|
): # Combine with prior kv_state and k_state
|
||||||
|
y += linear_factor * torch.einsum(
|
||||||
|
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln += (
|
||||||
|
linear_factor
|
||||||
|
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||||
|
)
|
||||||
|
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
return y, a # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsSlidingWindowAttention(LolcatsLinearAttention):
|
||||||
|
"""
|
||||||
|
Lolcats attention combining sliding window and linear attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_size: int = 64,
|
||||||
|
decode_window_size: Optional[int] = None,
|
||||||
|
affine_attention_factors: bool = False,
|
||||||
|
init_window_factor: float = 0,
|
||||||
|
train_window_factor: bool = True,
|
||||||
|
state_grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.window_size = window_size
|
||||||
|
self.decode_window_size = (
|
||||||
|
decode_window_size if decode_window_size is not None else window_size
|
||||||
|
)
|
||||||
|
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_sw'
|
||||||
|
# Determine how we compute attentions
|
||||||
|
self.quadratic_attention = hybrid_attention_quadratic
|
||||||
|
self.attention_type = kwargs[
|
||||||
|
"attention_type"
|
||||||
|
] # 'hedgehog_long_llama_window_sw'
|
||||||
|
# Learnable factor for combining attentions
|
||||||
|
self.affine_attention_factors = affine_attention_factors
|
||||||
|
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||||
|
if train_window_factor:
|
||||||
|
self.window_factors = nn.Parameter(
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"window_factors",
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||||
|
self.base_inference = False
|
||||||
|
self.state_grad_enabled = state_grad_enabled
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with the option to compute attention weights multiple ways
|
||||||
|
if self.train_attention is True
|
||||||
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||||
|
k
|
||||||
|
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||||
|
|
||||||
|
if self.train_attention:
|
||||||
|
# 1. Compute "ground-truth" attention output and weights
|
||||||
|
with torch.no_grad():
|
||||||
|
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||||
|
y_true = (
|
||||||
|
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
|
||||||
|
# 2. Compute "predicted" attention outputs
|
||||||
|
# compute attn weights under sliding window
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
y_pred, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||||
|
else:
|
||||||
|
attn_weights = None
|
||||||
|
# attention_mask = None # For now this is always True
|
||||||
|
if past_key_value is None: # Regular training
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
attn_weights = a_pred
|
||||||
|
else:
|
||||||
|
past_key_value.window_size = self.decode_window_size
|
||||||
|
if (
|
||||||
|
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||||
|
): # Generating
|
||||||
|
assert use_cache is True
|
||||||
|
_kv = past_key_value.update_for_decoding(
|
||||||
|
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||||
|
)
|
||||||
|
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||||
|
|
||||||
|
# Sliding window + linear attention decode
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax attention terms
|
||||||
|
a_sm = torch.einsum(
|
||||||
|
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||||
|
) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Combine with linear attention terms
|
||||||
|
y_true = torch.einsum(
|
||||||
|
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||||
|
) + linear_factors * torch.einsum(
|
||||||
|
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln = (
|
||||||
|
linear_factors
|
||||||
|
* torch.einsum(
|
||||||
|
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||||
|
)[..., None]
|
||||||
|
)
|
||||||
|
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
|
||||||
|
else: # Stateful training
|
||||||
|
try:
|
||||||
|
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||||
|
k_state = past_key_value.k_states[self.layer_idx]
|
||||||
|
except IndexError:
|
||||||
|
kv_state, k_state = None, None
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, _ = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
kv_state=kv_state,
|
||||||
|
k_state=k_state,
|
||||||
|
)
|
||||||
|
# Save and update KV cache and states
|
||||||
|
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||||
|
# fmap_key_states=f_k.detach(),
|
||||||
|
# accumulate_in_fp32=True)
|
||||||
|
past_key_value.update(
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
self.layer_idx,
|
||||||
|
fmap_key_states=f_k,
|
||||||
|
accumulate_in_fp32=True,
|
||||||
|
)
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||||
|
"""
|
||||||
|
Class for `past_key_values`
|
||||||
|
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||||
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int = 64) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||||
|
self._seen_tokens_by_layer: List[int] = []
|
||||||
|
self.kv_states: List[torch.Tensor] = []
|
||||||
|
self.k_states: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
# Account for sliding windows
|
||||||
|
self.decode_kv_states: List[torch.Tensor] = []
|
||||||
|
self.decode_k_states: List[torch.Tensor] = []
|
||||||
|
self.k_cache: List[torch.Tensor] = []
|
||||||
|
self.v_cache: List[torch.Tensor] = []
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
cache_kwargs: Optional[Any] = None,
|
||||||
|
accumulate_in_fp32: bool = False,
|
||||||
|
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||||
|
grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Update KV, K states; and KV cache during training
|
||||||
|
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||||
|
up to sliding window terms
|
||||||
|
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||||
|
up to end of sequence
|
||||||
|
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||||
|
"""
|
||||||
|
if fmap_key_states is None:
|
||||||
|
raise ValueError("fmap_key_states must not be None")
|
||||||
|
|
||||||
|
if layer_idx is None:
|
||||||
|
raise ValueError("Layer index must not be None")
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(grad_enabled):
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += key_states.shape[-2]
|
||||||
|
|
||||||
|
dtype = key_states.dtype
|
||||||
|
if accumulate_in_fp32:
|
||||||
|
# key_states = key_states.float()
|
||||||
|
fmap_key_states = fmap_key_states.float()
|
||||||
|
value_states = value_states.float()
|
||||||
|
|
||||||
|
# Decoding KV state (KV terms up to last window_size)
|
||||||
|
decode_kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, : -self.window_size],
|
||||||
|
value_states[:, :, : -self.window_size],
|
||||||
|
)
|
||||||
|
# KV state
|
||||||
|
kv_state = decode_kv_state + torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, -self.window_size :],
|
||||||
|
value_states[:, :, -self.window_size :],
|
||||||
|
)
|
||||||
|
# shape is b, h, 1, f; note the 1
|
||||||
|
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the cache
|
||||||
|
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||||
|
self.kv_states.append(kv_state.to(dtype))
|
||||||
|
self.k_states.append(k_state.to(dtype))
|
||||||
|
|
||||||
|
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||||
|
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||||
|
|
||||||
|
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||||
|
self.v_cache.append(
|
||||||
|
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||||
|
)
|
||||||
|
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||||
|
else:
|
||||||
|
# Update kv and k states recurrently
|
||||||
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
self.kv_states[layer_idx] = kv_state
|
||||||
|
self.k_states[layer_idx] = k_state
|
||||||
|
|
||||||
|
decode_kv_state = (
|
||||||
|
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||||
|
+ decode_kv_state
|
||||||
|
).to(dtype)
|
||||||
|
decode_k_state = (
|
||||||
|
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||||
|
).to(dtype)
|
||||||
|
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||||
|
self.decode_k_states[layer_idx] = decode_k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||||
|
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||||
|
|
||||||
|
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||||
|
|
||||||
|
def update_for_decoding(
|
||||||
|
self,
|
||||||
|
keys: torch.Tensor,
|
||||||
|
values: torch.Tensor,
|
||||||
|
layer_idx: int,
|
||||||
|
feature_map_k: Callable,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the decoding KV and K states, and KV cache, during decodeing
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
k_cache = self.k_cache[layer_idx]
|
||||||
|
v_cache = self.v_cache[layer_idx]
|
||||||
|
|
||||||
|
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||||
|
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||||
|
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||||
|
else:
|
||||||
|
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||||
|
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||||
|
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||||
|
# else:
|
||||||
|
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||||
|
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
v_state = v_cache[:, :, :1, :]
|
||||||
|
kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||||
|
).to(
|
||||||
|
dtype
|
||||||
|
) # b, h, f, d
|
||||||
|
self.decode_kv_states[layer_idx] += kv_state
|
||||||
|
self.decode_k_states[layer_idx] += k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = torch.cat(
|
||||||
|
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||||
|
)
|
||||||
|
self.v_cache[layer_idx] = torch.cat(
|
||||||
|
[v_cache[:, :, 1:, :], values], dim=-2
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += keys.shape[-2]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||||
|
return (
|
||||||
|
self.k_cache[layer_idx],
|
||||||
|
self.v_cache[layer_idx],
|
||||||
|
self.decode_kv_states[layer_idx],
|
||||||
|
self.decode_k_states[layer_idx],
|
||||||
|
)
|
||||||
@@ -0,0 +1,687 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using "standard" sliding windows
|
||||||
|
- Didactically computes outputs with n^2 attention weights for now
|
||||||
|
- Copied + adapted from linear_window_attention_tk.py for single-file reference
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
|
from ..model.rotary import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
causal_dot_product,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG = logging.getLogger(
|
||||||
|
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
|
||||||
|
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Return masks for softmax and linear attention terms
|
||||||
|
-> 1 is include, 0 is ignore
|
||||||
|
"""
|
||||||
|
causal_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||||
|
max(k_len - q_len, 0)
|
||||||
|
)
|
||||||
|
linear_mask = torch.ones((q_len, k_len), device=device, dtype=torch.int).tril(
|
||||||
|
max(k_len - q_len, 0) - window_size
|
||||||
|
)
|
||||||
|
window_mask = causal_mask - linear_mask
|
||||||
|
# Return softmax mask (window), linear attention mask
|
||||||
|
# -> shapes broadcast over (b, h, q_len, k_len)
|
||||||
|
return window_mask[None, None, ...], linear_mask[None, None, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
f_q: torch.Tensor,
|
||||||
|
f_k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_factor: torch.Tensor,
|
||||||
|
linear_factor: torch.Tensor,
|
||||||
|
window_size: int,
|
||||||
|
kv_state: Optional[torch.Tensor] = None,
|
||||||
|
k_state: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mask_value: float = -1e8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hybrid attention combining sliding window and linear attentions
|
||||||
|
"""
|
||||||
|
|
||||||
|
mask_window, mask_linear = get_masks(
|
||||||
|
window_size, q.shape[-2], k.shape[-2], q.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Sliding window (softmax attention)
|
||||||
|
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||||
|
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 2. Under window (linear attention)
|
||||||
|
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||||
|
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||||
|
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 3. Combine
|
||||||
|
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||||
|
# Allow outputs to also depend on prior kv_state and k_state
|
||||||
|
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||||
|
if (
|
||||||
|
kv_state is not None and k_state is not None
|
||||||
|
): # Combine with prior kv_state and k_state
|
||||||
|
y += linear_factor * torch.einsum(
|
||||||
|
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln += (
|
||||||
|
linear_factor
|
||||||
|
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||||
|
)
|
||||||
|
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
return y, a # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------
|
||||||
|
# Hybrid window attention linear
|
||||||
|
# ------------------------------
|
||||||
|
def under_window_linear_attention(
|
||||||
|
f_q: torch.Tensor,
|
||||||
|
f_k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_size: int,
|
||||||
|
linear_factor: torch.Tensor,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
):
|
||||||
|
"""Compute hybrid window attention dot product with linear complexity in q_len"""
|
||||||
|
dtype = f_q.dtype
|
||||||
|
w = window_size
|
||||||
|
f_k = F.pad(f_k, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||||
|
v = F.pad(v, (0, 0, w, 0), value=0)[:, :, :-w, :]
|
||||||
|
qkv = linear_factor * causal_dot_product(
|
||||||
|
f_q.contiguous().to(dtype=torch.float32),
|
||||||
|
f_k.contiguous().to(dtype=torch.float32),
|
||||||
|
v.contiguous().to(dtype=torch.float32),
|
||||||
|
).to(dtype=dtype)
|
||||||
|
sum_f_k = f_k.float().cumsum(dim=2).to(dtype=dtype)
|
||||||
|
sum_qk = linear_factor * torch.einsum("bhld,bhld->bhl", f_q, sum_f_k)[..., None]
|
||||||
|
sum_qk[sum_qk == 0] += eps
|
||||||
|
return qkv, sum_qk
|
||||||
|
|
||||||
|
|
||||||
|
def sliding_window_softmax_attention(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_size: int,
|
||||||
|
window_factor: torch.Tensor,
|
||||||
|
mask_value: float = -1e8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Compute sliding window softmax attention without materializing
|
||||||
|
O(seq_len^2) attention weights
|
||||||
|
"""
|
||||||
|
d = q.shape[-1]
|
||||||
|
# Compute windows for keys
|
||||||
|
window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||||
|
k = F.pad(k, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||||
|
v = F.pad(v, (0, 0, window_size - 1, 0), value=0).unfold(**window_kwargs)
|
||||||
|
|
||||||
|
# Compute windowed_softmax(qk); causal in its construction
|
||||||
|
a_sm = torch.einsum("bhld,bhldw->bhlw", q, k) * (d**-0.5)
|
||||||
|
a_sm[a_sm == 0] = -torch.finfo(
|
||||||
|
q.dtype
|
||||||
|
).max # heuristic for zeroing out padding above
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
return torch.einsum("bhlw,bhldw->bhld", a_sm, v), sum_sm
|
||||||
|
# return torch.einsum('bhlw,bhldw->bhld', torch.softmax(qk, dim=-1), v)
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_linear(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
f_q: torch.Tensor,
|
||||||
|
f_k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_factor: Optional[torch.Tensor] = None,
|
||||||
|
linear_factor: Optional[torch.Tensor] = None,
|
||||||
|
window_size: int = 64,
|
||||||
|
kv_state: Optional[torch.Tensor] = None,
|
||||||
|
k_state: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mask_value: float = -1e8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Alternative hybrid attention combining sliding window and linear attentions
|
||||||
|
-> Uses O(n) memory if n is sequence length by padding and unfolding windows
|
||||||
|
"""
|
||||||
|
# window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||||
|
if window_factor is None:
|
||||||
|
raise ValueError("window_factor must be provided")
|
||||||
|
|
||||||
|
if linear_factor is None:
|
||||||
|
raise ValueError("linear_factor must be provided")
|
||||||
|
|
||||||
|
# 1. Sliding window (softmax attention)
|
||||||
|
with torch.no_grad():
|
||||||
|
qkv_sm, sum_qk_sm = sliding_window_softmax_attention(
|
||||||
|
q, k, v, window_size, window_factor, mask_value
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Under window (linear attention)
|
||||||
|
qkv_ln, sum_qk_ln = under_window_linear_attention(
|
||||||
|
f_q, f_k, v, window_size, linear_factor, eps
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Combine
|
||||||
|
y = (qkv_sm + qkv_ln) / (sum_qk_sm + sum_qk_ln)
|
||||||
|
return y, None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsLinearSlidingWindowAttention(LolcatsLinearAttention):
|
||||||
|
"""
|
||||||
|
Lolcats attention combining sliding window and linear attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_size: int = 64,
|
||||||
|
decode_window_size: Optional[int] = None,
|
||||||
|
affine_attention_factors: bool = False,
|
||||||
|
init_window_factor: float = 0,
|
||||||
|
train_window_factor: bool = True,
|
||||||
|
state_grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.window_size = window_size
|
||||||
|
self.decode_window_size = (
|
||||||
|
decode_window_size if decode_window_size is not None else window_size
|
||||||
|
)
|
||||||
|
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# Determine how we compute attentions
|
||||||
|
self.linear_attention = hybrid_attention_linear
|
||||||
|
self.attention_type = "lolcats_llama_window_sw"
|
||||||
|
# Learnable factor for combining attentions
|
||||||
|
self.affine_attention_factors = affine_attention_factors
|
||||||
|
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||||
|
if train_window_factor:
|
||||||
|
self.window_factors = nn.Parameter(
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"window_factors",
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||||
|
self.base_inference = False
|
||||||
|
self.state_grad_enabled = state_grad_enabled
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with the option to compute attention weights multiple ways
|
||||||
|
if self.train_attention is True
|
||||||
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
|
||||||
|
if self.train_attention and self.base_inference:
|
||||||
|
with torch.no_grad():
|
||||||
|
_y_true = flash_attention_2(
|
||||||
|
self, # self.base_attn,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
use_cache=False,
|
||||||
|
)[0]
|
||||||
|
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||||
|
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
# layer_io = (hidden_states, _y_true) # hack
|
||||||
|
layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||||
|
return y_true, layer_io, None
|
||||||
|
|
||||||
|
else:
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||||
|
k
|
||||||
|
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||||
|
|
||||||
|
attn_weights = None
|
||||||
|
# attention_mask = None # For now this is always True
|
||||||
|
if past_key_value is None: # Regular training
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, a_pred = self.linear_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
attn_weights = a_pred
|
||||||
|
else:
|
||||||
|
past_key_value.window_size = self.decode_window_size
|
||||||
|
if (
|
||||||
|
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||||
|
): # Generating
|
||||||
|
assert use_cache is True
|
||||||
|
_kv = past_key_value.update_for_decoding(
|
||||||
|
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||||
|
)
|
||||||
|
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||||
|
|
||||||
|
# Sliding window + linear attention decode
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax attention terms
|
||||||
|
a_sm = torch.einsum(
|
||||||
|
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||||
|
) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Combine with linear attention terms
|
||||||
|
y_true = torch.einsum(
|
||||||
|
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||||
|
) + linear_factors * torch.einsum(
|
||||||
|
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln = (
|
||||||
|
linear_factors
|
||||||
|
* torch.einsum(
|
||||||
|
"bhlf,bhnf->bhl", f_q.float(), f_k_state.float()
|
||||||
|
)[..., None]
|
||||||
|
)
|
||||||
|
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
|
||||||
|
else: # Stateful training
|
||||||
|
try:
|
||||||
|
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||||
|
k_state = past_key_value.k_states[self.layer_idx]
|
||||||
|
except IndexError:
|
||||||
|
kv_state, k_state = None, None
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, _ = self.linear_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
kv_state=kv_state,
|
||||||
|
k_state=k_state,
|
||||||
|
)
|
||||||
|
# Save and update KV cache and states
|
||||||
|
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||||
|
# fmap_key_states=f_k.detach(),
|
||||||
|
# accumulate_in_fp32=True)
|
||||||
|
past_key_value.update(
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
self.layer_idx,
|
||||||
|
fmap_key_states=f_k,
|
||||||
|
accumulate_in_fp32=True,
|
||||||
|
)
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
_y_true = y_true.transpose(1, 2).contiguous()
|
||||||
|
y_true = self.o_proj(_y_true.view(b, l, self.hidden_size))
|
||||||
|
|
||||||
|
if self.train_attention:
|
||||||
|
attn_weights = _y_true # flash_attn outputs are shape (b, l, h, d)
|
||||||
|
return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionSlidingWindowCache(LinearAttentionState):
|
||||||
|
"""
|
||||||
|
Class for `past_key_values`
|
||||||
|
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||||
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int = 64) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||||
|
self._seen_tokens_by_layer: List[int] = []
|
||||||
|
self.kv_states: List[torch.Tensor] = []
|
||||||
|
self.k_states: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
# Account for sliding windows
|
||||||
|
self.decode_kv_states: List[torch.Tensor] = []
|
||||||
|
self.decode_k_states: List[torch.Tensor] = []
|
||||||
|
self.k_cache: List[torch.Tensor] = []
|
||||||
|
self.v_cache: List[torch.Tensor] = []
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
cache_kwargs: Optional[Any] = None,
|
||||||
|
accumulate_in_fp32: bool = False,
|
||||||
|
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||||
|
grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Update KV, K states; and KV cache during training
|
||||||
|
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||||
|
up to sliding window terms
|
||||||
|
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||||
|
up to end of sequence
|
||||||
|
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||||
|
"""
|
||||||
|
if fmap_key_states is None:
|
||||||
|
raise ValueError("fmap_key_states must not be None")
|
||||||
|
|
||||||
|
if layer_idx is None:
|
||||||
|
raise ValueError("Layer index must not be None")
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(grad_enabled):
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += key_states.shape[-2]
|
||||||
|
|
||||||
|
dtype = key_states.dtype
|
||||||
|
if accumulate_in_fp32:
|
||||||
|
# key_states = key_states.float()
|
||||||
|
fmap_key_states = fmap_key_states.float()
|
||||||
|
value_states = value_states.float()
|
||||||
|
|
||||||
|
# Decoding KV state (KV terms up to last window_size)
|
||||||
|
decode_kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, : -self.window_size],
|
||||||
|
value_states[:, :, : -self.window_size],
|
||||||
|
)
|
||||||
|
# KV state
|
||||||
|
kv_state = decode_kv_state + torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, -self.window_size :],
|
||||||
|
value_states[:, :, -self.window_size :],
|
||||||
|
)
|
||||||
|
# shape is b, h, 1, f; note the 1
|
||||||
|
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the cache
|
||||||
|
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||||
|
self.kv_states.append(kv_state.to(dtype))
|
||||||
|
self.k_states.append(k_state.to(dtype))
|
||||||
|
|
||||||
|
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||||
|
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||||
|
|
||||||
|
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||||
|
self.v_cache.append(
|
||||||
|
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||||
|
)
|
||||||
|
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||||
|
else:
|
||||||
|
# Update kv and k states recurrently
|
||||||
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
self.kv_states[layer_idx] = kv_state
|
||||||
|
self.k_states[layer_idx] = k_state
|
||||||
|
|
||||||
|
decode_kv_state = (
|
||||||
|
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||||
|
+ decode_kv_state
|
||||||
|
).to(dtype)
|
||||||
|
decode_k_state = (
|
||||||
|
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||||
|
).to(dtype)
|
||||||
|
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||||
|
self.decode_k_states[layer_idx] = decode_k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||||
|
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||||
|
|
||||||
|
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||||
|
|
||||||
|
def update_for_decoding(
|
||||||
|
self,
|
||||||
|
keys: torch.Tensor,
|
||||||
|
values: torch.Tensor,
|
||||||
|
layer_idx: int,
|
||||||
|
feature_map_k: Callable,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the decoding KV and K states, and KV cache, during decodeing
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
k_cache = self.k_cache[layer_idx]
|
||||||
|
v_cache = self.v_cache[layer_idx]
|
||||||
|
|
||||||
|
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||||
|
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||||
|
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||||
|
else:
|
||||||
|
# MZ 6/3: handle short inputs; zero-out padding when initial k.shape[2] < self.window_size
|
||||||
|
# if k_cache[:, :, :1, :].sum() == 0: # heuristic for zeroing out padding in cache
|
||||||
|
# f_k_state = torch.zeros(k_cache[:, :, :1, :].shape, dtype=dtype, device=k_cache.device)
|
||||||
|
# else:
|
||||||
|
# f_k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
# -> MZ (later): above only relevant if we zero-pad in our hybrid attention computation
|
||||||
|
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
v_state = v_cache[:, :, :1, :]
|
||||||
|
kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||||
|
).to(
|
||||||
|
dtype
|
||||||
|
) # b, h, f, d
|
||||||
|
self.decode_kv_states[layer_idx] += kv_state
|
||||||
|
self.decode_k_states[layer_idx] += k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = torch.cat(
|
||||||
|
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||||
|
)
|
||||||
|
self.v_cache[layer_idx] = torch.cat(
|
||||||
|
[v_cache[:, :, 1:, :], values], dim=-2
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += keys.shape[-2]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||||
|
return (
|
||||||
|
self.k_cache[layer_idx],
|
||||||
|
self.v_cache[layer_idx],
|
||||||
|
self.decode_kv_states[layer_idx],
|
||||||
|
self.decode_k_states[layer_idx],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Flash Attention 2
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Wrapper for LlamaFlashAttention2
|
||||||
|
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||||
|
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||||
|
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||||
|
"""
|
||||||
|
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
# therefore we just need to keep the original shape
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
try: # As in Transformers v4.36
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
except Exception: # As in Transformers v4.39
|
||||||
|
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
if getattr(self, "_flash_attention_forward", False):
|
||||||
|
attn_output = self._flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=0, # dropout_rate,
|
||||||
|
sliding_window=getattr(self, "sliding_window", None),
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
return attn_output, past_key_value
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs attention combining sliding window and linear attentions
|
||||||
|
- Using standard sliding window arrangement
|
||||||
|
- Training over long sequences with fixed memory with recurrent view
|
||||||
|
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
from .linear_window_attention_sw import hybrid_attention_quadratic
|
||||||
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsSlidingWindowLongAttention(LolcatsTKWindowLongAttention):
|
||||||
|
"""
|
||||||
|
Lolcats attention combining sliding window and linear attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, remove_base_attn=True, **kwargs):
|
||||||
|
# keep self.base_attn for Flash Attention inference
|
||||||
|
super().__init__(remove_base_attn=True, **kwargs)
|
||||||
|
self.quadratic_attention = hybrid_attention_quadratic
|
||||||
@@ -0,0 +1,466 @@
|
|||||||
|
"""
|
||||||
|
Subquadratic attention combining sliding window and linear attentions
|
||||||
|
- Using the TK "terracing" arrangement
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
from typing import Any, Callable, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
from .linear_attention import (
|
||||||
|
LinearAttentionState,
|
||||||
|
LolcatsLinearAttention,
|
||||||
|
softmax_attention,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------------
|
||||||
|
# Sliding window helpers
|
||||||
|
# ----------------------
|
||||||
|
def get_masks(
|
||||||
|
window_size: int, q_len: int, k_len: int, device: torch.device
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Return masks for softmax and linear attention terms
|
||||||
|
-> 1 is include, 0 is ignore
|
||||||
|
"""
|
||||||
|
win_len = window_size
|
||||||
|
m = math.ceil(max(q_len, k_len) / window_size)
|
||||||
|
# Creates an n x n mask where n = window_size^2
|
||||||
|
mask = torch.block_diag(
|
||||||
|
*[
|
||||||
|
torch.ones(
|
||||||
|
(win_len, win_len),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
* m
|
||||||
|
)
|
||||||
|
mask += torch.roll(mask, -win_len, -1) # this adds the terracing
|
||||||
|
if mask.shape[0] > q_len:
|
||||||
|
mask = mask[-q_len:]
|
||||||
|
if mask.shape[1] > k_len:
|
||||||
|
mask = mask[:, -k_len:]
|
||||||
|
# Return softmax mask (window), linear attention mask
|
||||||
|
mask = mask[None, None, ...] # b, h, q_len, k_len
|
||||||
|
return (
|
||||||
|
torch.tril(mask).to(device=device, dtype=torch.int),
|
||||||
|
torch.tril(1 - mask).to(device=device, dtype=torch.int),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hybrid_attention_quadratic(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
f_q: torch.Tensor,
|
||||||
|
f_k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
window_factor: torch.Tensor,
|
||||||
|
linear_factor: torch.Tensor,
|
||||||
|
window_size: int,
|
||||||
|
kv_state: Optional[torch.Tensor] = None,
|
||||||
|
k_state: Optional[torch.Tensor] = None,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mask_value: float = -1e8,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hybrid attention combining sliding window and linear attentions
|
||||||
|
"""
|
||||||
|
|
||||||
|
mask_window, mask_linear = get_masks(
|
||||||
|
window_size, q.shape[-2], k.shape[-2], q.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Sliding window (softmax attention)
|
||||||
|
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k.float()) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm = a_sm.masked_fill(~mask_window.bool(), mask_value)
|
||||||
|
# torch.softmax(a_sm, dim=-1), but we account for the max when combining
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factor * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 2. Under window (linear attention)
|
||||||
|
a_ln = torch.einsum("bhmd,bhnd->bhmn", f_q.float(), f_k.float())
|
||||||
|
a_ln = linear_factor * a_ln.masked_fill(~mask_linear.bool(), 0)
|
||||||
|
sum_ln = a_ln.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# 3. Combine
|
||||||
|
a = ((a_sm + a_ln) / (sum_sm + sum_ln)).to(q.dtype) # Save attention weights
|
||||||
|
# Allow outputs to also depend on prior kv_state and k_state
|
||||||
|
y = torch.einsum("bhmn,bhnd->bhmd", a_sm + a_ln, v.float())
|
||||||
|
if (
|
||||||
|
kv_state is not None and k_state is not None
|
||||||
|
): # Combine with prior kv_state and k_state
|
||||||
|
y += linear_factor * torch.einsum(
|
||||||
|
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln += (
|
||||||
|
linear_factor
|
||||||
|
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[..., None]
|
||||||
|
)
|
||||||
|
y = (y / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
return y, a # attention weights only for the last chunk
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------
|
||||||
|
# Attention layer class
|
||||||
|
# ---------------------
|
||||||
|
class LolcatsTKWindowAttention(LolcatsLinearAttention):
|
||||||
|
"""
|
||||||
|
Lolcats attention combining sliding window and linear attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
window_size: int = 64,
|
||||||
|
decode_window_size: Optional[int] = None,
|
||||||
|
affine_attention_factors: bool = False,
|
||||||
|
init_window_factor: float = 0,
|
||||||
|
train_window_factor: bool = True,
|
||||||
|
state_grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.window_size = window_size
|
||||||
|
self.decode_window_size = (
|
||||||
|
decode_window_size if decode_window_size is not None else window_size
|
||||||
|
)
|
||||||
|
self.window_kwargs = {"dimension": 2, "size": window_size, "step": 1}
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.attention_type = kwargs["attention_type"] # 'hedgehog_llama_window_tk'
|
||||||
|
# Determine how we compute attentions
|
||||||
|
self.quadratic_attention = hybrid_attention_quadratic
|
||||||
|
self.attention_type = kwargs[
|
||||||
|
"attention_type"
|
||||||
|
] # 'hedgehog_long_llama_window_tk'
|
||||||
|
# Learnable factor for combining attentions
|
||||||
|
self.affine_attention_factors = affine_attention_factors
|
||||||
|
device, dtype = self.q_proj.weight.device, self.q_proj.weight.dtype
|
||||||
|
if train_window_factor:
|
||||||
|
self.window_factors = nn.Parameter(
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.register_buffer(
|
||||||
|
"window_factors",
|
||||||
|
init_window_factor
|
||||||
|
* torch.ones(1, self.num_heads, 1, 1, device=device, dtype=dtype),
|
||||||
|
)
|
||||||
|
# Whether we use original flash attention 2 inference (use during attention transfer)
|
||||||
|
self.base_inference = False
|
||||||
|
self.state_grad_enabled = state_grad_enabled
|
||||||
|
self.window_factor = self.window_factors # legacy naming support
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with the option to compute attention weights multiple ways
|
||||||
|
if self.train_attention is True
|
||||||
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(
|
||||||
|
k
|
||||||
|
) # Have to do after repeat for grouped-query attn if we use same fmap
|
||||||
|
|
||||||
|
if self.train_attention:
|
||||||
|
# 1. Compute "ground-truth" attention output and weights
|
||||||
|
with torch.no_grad():
|
||||||
|
_y_true, a_true = softmax_attention(q, k, v)[:2]
|
||||||
|
y_true = (
|
||||||
|
_y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
|
||||||
|
# 2. Compute "predicted" attention outputs
|
||||||
|
# compute attn weights under sliding window
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
y_pred, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
attn_weights = ((a_pred, a_true), (y_pred, _y_true))
|
||||||
|
else:
|
||||||
|
attn_weights = None
|
||||||
|
# attention_mask = None # For now this is always True
|
||||||
|
if past_key_value is None: # Regular training
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
attn_weights = a_pred
|
||||||
|
else:
|
||||||
|
past_key_value.window_size = self.decode_window_size
|
||||||
|
if (
|
||||||
|
f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training
|
||||||
|
): # Generating
|
||||||
|
assert use_cache is True
|
||||||
|
_kv = past_key_value.update_for_decoding(
|
||||||
|
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||||
|
)
|
||||||
|
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||||
|
|
||||||
|
# Sliding window + linear attention decode
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# Softmax attention terms
|
||||||
|
a_sm = torch.einsum(
|
||||||
|
"bhmd,bhnd->bhmn", q.float(), k_cache.float()
|
||||||
|
) * (k.shape[-1] ** -0.5)
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Combine with linear attention terms
|
||||||
|
y_true = torch.einsum(
|
||||||
|
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||||
|
) + linear_factors * torch.einsum(
|
||||||
|
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln = (
|
||||||
|
linear_factors
|
||||||
|
* torch.einsum(
|
||||||
|
"bhld,bhnd->bhl", f_q.float(), f_k_state.float()
|
||||||
|
)[..., None]
|
||||||
|
)
|
||||||
|
y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
|
||||||
|
else: # Stateful training
|
||||||
|
try:
|
||||||
|
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||||
|
k_state = past_key_value.k_states[self.layer_idx]
|
||||||
|
except IndexError:
|
||||||
|
kv_state, k_state = None, None
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_true, _ = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
kv_state=kv_state,
|
||||||
|
k_state=k_state,
|
||||||
|
)
|
||||||
|
# Save and update KV cache and states
|
||||||
|
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||||
|
# fmap_key_states=f_k.detach(),
|
||||||
|
# accumulate_in_fp32=True)
|
||||||
|
past_key_value.update(
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
self.layer_idx,
|
||||||
|
fmap_key_states=f_k,
|
||||||
|
accumulate_in_fp32=True,
|
||||||
|
)
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
y_true = y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
return y_true, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionTKWindowCache(LinearAttentionState):
|
||||||
|
"""
|
||||||
|
Class for `past_key_values`
|
||||||
|
-> Alternative to KV cache; here we only maintain a "KV state" and "K state"
|
||||||
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int = 64) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||||
|
self._seen_tokens_by_layer: List[int] = []
|
||||||
|
self.kv_states: List[torch.Tensor] = []
|
||||||
|
self.k_states: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
# Account for sliding windows
|
||||||
|
self.decode_kv_states: List[torch.Tensor] = []
|
||||||
|
self.decode_k_states: List[torch.Tensor] = []
|
||||||
|
self.k_cache: List[torch.Tensor] = []
|
||||||
|
self.v_cache: List[torch.Tensor] = []
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
def update(
|
||||||
|
self,
|
||||||
|
key_states: torch.Tensor,
|
||||||
|
value_states: torch.Tensor,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
|
cache_kwargs: Optional[Any] = None,
|
||||||
|
accumulate_in_fp32: bool = False,
|
||||||
|
fmap_key_states: Optional[torch.Tensor] = None, # should not be None
|
||||||
|
grad_enabled: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Update KV, K states; and KV cache during training
|
||||||
|
- For decoding, use `self.decode_kv_states` to keep track of KV states
|
||||||
|
up to sliding window terms
|
||||||
|
- For (chunked) training, use `self.kv_states` to keep track of KV states
|
||||||
|
up to end of sequence
|
||||||
|
- Likewise for `self.decode_k_states` and `self.k_states`
|
||||||
|
"""
|
||||||
|
if fmap_key_states is None:
|
||||||
|
raise ValueError("fmap_key_states should not be None")
|
||||||
|
|
||||||
|
if layer_idx is None:
|
||||||
|
raise ValueError("layer_idx should not be None")
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(grad_enabled):
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += key_states.shape[-2]
|
||||||
|
|
||||||
|
dtype = key_states.dtype
|
||||||
|
if accumulate_in_fp32:
|
||||||
|
# key_states = key_states.float()
|
||||||
|
fmap_key_states = fmap_key_states.float()
|
||||||
|
value_states = value_states.float()
|
||||||
|
|
||||||
|
# Decoding KV state (KV terms up to last window_size)
|
||||||
|
decode_kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, : -self.window_size],
|
||||||
|
value_states[:, :, : -self.window_size],
|
||||||
|
)
|
||||||
|
# KV state
|
||||||
|
kv_state = decode_kv_state + torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd",
|
||||||
|
fmap_key_states[:, :, -self.window_size :],
|
||||||
|
value_states[:, :, -self.window_size :],
|
||||||
|
)
|
||||||
|
# shape is b, h, 1, f; note the 1
|
||||||
|
decode_k_state = fmap_key_states[:, :, : -self.window_size].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
k_state = decode_k_state + fmap_key_states[:, :, -self.window_size :].sum(
|
||||||
|
dim=-2, keepdim=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update the cache
|
||||||
|
if len(self.k_states) <= layer_idx: # Initializing kv and k states
|
||||||
|
self.kv_states.append(kv_state.to(dtype))
|
||||||
|
self.k_states.append(k_state.to(dtype))
|
||||||
|
|
||||||
|
self.decode_kv_states.append(decode_kv_state.to(dtype))
|
||||||
|
self.decode_k_states.append(decode_k_state.to(dtype))
|
||||||
|
|
||||||
|
self.k_cache.append(key_states[:, :, -self.window_size :, :])
|
||||||
|
self.v_cache.append(
|
||||||
|
value_states[:, :, -self.window_size :, :].to(dtype)
|
||||||
|
)
|
||||||
|
# self._seen_tokens_by_layer[layer_idx].append(key_states.shape[-2])
|
||||||
|
else:
|
||||||
|
# Update kv and k states recurrently
|
||||||
|
kv_state = (self.kv_states[layer_idx].to(kv_state.dtype) + kv_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
k_state = (self.k_states[layer_idx].to(kv_state.dtype) + k_state).to(
|
||||||
|
dtype
|
||||||
|
)
|
||||||
|
self.kv_states[layer_idx] = kv_state
|
||||||
|
self.k_states[layer_idx] = k_state
|
||||||
|
|
||||||
|
decode_kv_state = (
|
||||||
|
self.decode_kv_states[layer_idx].to(kv_state.dtype)
|
||||||
|
+ decode_kv_state
|
||||||
|
).to(dtype)
|
||||||
|
decode_k_state = (
|
||||||
|
self.decode_k_states[layer_idx].to(kv_state.dtype) + decode_k_state
|
||||||
|
).to(dtype)
|
||||||
|
self.decode_kv_states[layer_idx] = decode_kv_state
|
||||||
|
self.decode_k_states[layer_idx] = decode_k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = key_states[:, :, -self.window_size :, :]
|
||||||
|
self.v_cache[layer_idx] = value_states[:, :, -self.window_size :, :]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += key_states.shape[-2]
|
||||||
|
|
||||||
|
return self.kv_states[layer_idx], self.k_states[layer_idx]
|
||||||
|
|
||||||
|
def update_for_decoding(
|
||||||
|
self,
|
||||||
|
keys: torch.Tensor,
|
||||||
|
values: torch.Tensor,
|
||||||
|
layer_idx: int,
|
||||||
|
feature_map_k: Callable,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the decoding KV and K states, and KV cache, during decodeing
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
k_cache = self.k_cache[layer_idx]
|
||||||
|
v_cache = self.v_cache[layer_idx]
|
||||||
|
|
||||||
|
if k_cache.shape[-2] < self.window_size: # build window-size cache
|
||||||
|
self.k_cache[layer_idx] = torch.cat([k_cache, keys], dim=-2)
|
||||||
|
self.v_cache[layer_idx] = torch.cat([v_cache, values], dim=-2)
|
||||||
|
else:
|
||||||
|
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
v_state = v_cache[:, :, :1, :]
|
||||||
|
kv_state = torch.einsum(
|
||||||
|
"bhlf,bhld->bhfd", k_state.float(), v_state.float()
|
||||||
|
).to(
|
||||||
|
dtype
|
||||||
|
) # b, h, f, d
|
||||||
|
self.decode_kv_states[layer_idx] += kv_state
|
||||||
|
self.decode_k_states[layer_idx] += k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = torch.cat(
|
||||||
|
[k_cache[:, :, 1:, :], keys], dim=-2
|
||||||
|
)
|
||||||
|
self.v_cache[layer_idx] = torch.cat(
|
||||||
|
[v_cache[:, :, 1:, :], values], dim=-2
|
||||||
|
)
|
||||||
|
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += keys.shape[-2]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += keys.shape[-2]
|
||||||
|
return (
|
||||||
|
self.k_cache[layer_idx],
|
||||||
|
self.v_cache[layer_idx],
|
||||||
|
self.decode_kv_states[layer_idx],
|
||||||
|
self.decode_k_states[layer_idx],
|
||||||
|
)
|
||||||
@@ -0,0 +1,221 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs + ThunderKittens linear attention + sliding window for generation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .linear_attention import LinearAttentionState
|
||||||
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
|
LOG = logging.getLogger(
|
||||||
|
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||||
|
|
||||||
|
LOG.debug("Successfully imported ThunderKittens for TK window attention")
|
||||||
|
except ImportError:
|
||||||
|
LOG.debug("Failed to import ThunderKittens for TK window attention")
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsWindowAttentionTKGen(LolcatsTKWindowLongAttention):
|
||||||
|
def __init__(self, *args, window_size: int = 64, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.train_attention = False
|
||||||
|
self.base_inference = False
|
||||||
|
self.window_size = 64 # hard-coded support for TK kernel
|
||||||
|
self.decode_window_size = 64
|
||||||
|
|
||||||
|
b, h, l, d = 1, 32, 8192, 128
|
||||||
|
self.y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device="cuda")
|
||||||
|
self.kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device="cuda")
|
||||||
|
self.k_state = torch.zeros(b, h, d, dtype=torch.float32, device="cuda")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Any] = None, # “legacy” cache approach
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with the option to compute attention weights multiple ways
|
||||||
|
if self.train_attention is True
|
||||||
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
assert (
|
||||||
|
past_key_value is not None
|
||||||
|
), "past_key_value must be provided for generation"
|
||||||
|
assert (
|
||||||
|
self.train_attention is False
|
||||||
|
), "train_attention is not supported for generation"
|
||||||
|
assert (
|
||||||
|
self.base_inference is False
|
||||||
|
), "base_inference is not supported for generation"
|
||||||
|
assert use_cache is True, "use_cache must be True for generation"
|
||||||
|
past_key_value.window_size = self.decode_window_size
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
if q.shape[2] == 1 and kv_seq_len > 1: # Generating after prefill
|
||||||
|
f_q = self.feature_map_q(q)
|
||||||
|
_kv = past_key_value.update_for_decoding(
|
||||||
|
k, v, self.layer_idx, self.feature_map_k
|
||||||
|
)
|
||||||
|
k_cache, v_cache, kv_state, k_state = _kv
|
||||||
|
# Sliding window + linear attention decode
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
|
||||||
|
# Softmax attention terms
|
||||||
|
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||||
|
k.shape[-1] ** -0.5
|
||||||
|
)
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
# Combine with linear attention terms
|
||||||
|
y_true = torch.einsum(
|
||||||
|
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||||
|
) + linear_factors * torch.einsum(
|
||||||
|
"bhld,bhdf->bhlf", f_q.float(), kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln = (
|
||||||
|
linear_factors
|
||||||
|
* torch.einsum("bhld,bhnd->bhl", f_q.float(), k_state.float())[
|
||||||
|
..., None
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.y_true = (y_true / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
|
||||||
|
else: # Process prefill
|
||||||
|
# Use TK-implemented linear + terrace window attention
|
||||||
|
b, h, l, d = q.shape
|
||||||
|
device = q.device
|
||||||
|
# tk.hedgehog arguments
|
||||||
|
# y_true = torch.zeros(b, h, l, d, dtype=torch.bfloat16, device=device)
|
||||||
|
# kv_state = torch.zeros(b, h, d, d, dtype=torch.float32, device=device)
|
||||||
|
# k_state = torch.zeros(b, h, d, dtype=torch.float32, device=device)
|
||||||
|
betas = F.sigmoid(self.window_factors[0, :, 0, 0].to(dtype=torch.float32))
|
||||||
|
alphas = (
|
||||||
|
1 - betas
|
||||||
|
if self.affine_attention_factors
|
||||||
|
else torch.ones(betas.shape, dtype=torch.float32, device=device)
|
||||||
|
)
|
||||||
|
q_map = self.feature_map_q.mlp.layer
|
||||||
|
k_map = self.feature_map_k.mlp.layer
|
||||||
|
# Saves outputs to y_pred, k_state, kv_state, where we fuse:
|
||||||
|
# 1. f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||||
|
# 2. y_pred = attention(q, k, f_q, f_k, v) # b, h, l, d
|
||||||
|
# 3. kv_state = torch.einsum(‘bhlf,bhld->bhfd’,
|
||||||
|
# f_k[:, :, :-self.window_size],
|
||||||
|
# v[:, :, :-self.window_size]) # b, h, f, d
|
||||||
|
# 4. k_state = f_k[:, :, :-self.window_size].sum(dim=-2) # b, h, d
|
||||||
|
|
||||||
|
tk_window_hedgehog_attention(
|
||||||
|
q.contiguous(),
|
||||||
|
k.contiguous(),
|
||||||
|
v.contiguous(),
|
||||||
|
self.y_true,
|
||||||
|
self.k_state,
|
||||||
|
self.kv_state,
|
||||||
|
q_map,
|
||||||
|
k_map,
|
||||||
|
alphas,
|
||||||
|
betas,
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_value.update_with_kv(
|
||||||
|
self.kv_state, self.k_state.unsqueeze(-2), k, v, self.layer_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
y_true = self.y_true.transpose(1, 2).contiguous().view(b, l, self.hidden_size)
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
return y_true, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
class LinearAttentionTKWindowGenerationCache(LinearAttentionState):
|
||||||
|
"""
|
||||||
|
Class for `past_key_values`
|
||||||
|
-> Alternative to KV cache; here we only maintain a “KV state” and “K state”
|
||||||
|
-> Modified from transformers.cache_utils.DynamicCache (v4.36)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size: int = 64) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self._seen_tokens = 0 # should be `self.seen_tokens` in Transformers v4.36
|
||||||
|
self._seen_tokens_by_layer: List[int] = []
|
||||||
|
self.window_size = window_size
|
||||||
|
|
||||||
|
self.decode_kv_states: List[torch.Tensor] = []
|
||||||
|
self.decode_k_states: List[torch.Tensor] = []
|
||||||
|
self.k_cache: List[torch.Tensor] = []
|
||||||
|
self.v_cache: List[torch.Tensor] = []
|
||||||
|
|
||||||
|
def update_with_kv(
|
||||||
|
self,
|
||||||
|
kv_state: torch.Tensor,
|
||||||
|
k_state: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
layer_idx: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the cache with new KV and K states
|
||||||
|
"""
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += k.shape[2]
|
||||||
|
self._seen_tokens_by_layer.append(k.shape[2])
|
||||||
|
|
||||||
|
# Initialize KV and K states
|
||||||
|
if len(self.decode_k_states) <= layer_idx:
|
||||||
|
self.decode_kv_states.append(kv_state)
|
||||||
|
self.decode_k_states.append(k_state)
|
||||||
|
else: # Update KV and K states
|
||||||
|
self.decode_kv_states[layer_idx] = (
|
||||||
|
self.decode_kv_states[layer_idx] + kv_state
|
||||||
|
)
|
||||||
|
self.decode_k_states[layer_idx] = self.decode_k_states[layer_idx] + k_state
|
||||||
|
|
||||||
|
self.k_cache.append(k[:, :, -self.window_size :, :])
|
||||||
|
self.v_cache.append(v[:, :, -self.window_size :, :])
|
||||||
|
|
||||||
|
def update_for_decoding(
|
||||||
|
self, k: torch.Tensor, v: torch.Tensor, layer_idx: int, feature_map_k: Callable
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the cache for decoding
|
||||||
|
"""
|
||||||
|
k_cache = self.k_cache[layer_idx]
|
||||||
|
v_cache = self.v_cache[layer_idx]
|
||||||
|
k_state = feature_map_k(k_cache[:, :, :1, :])
|
||||||
|
v_state = v_cache[:, :, :1, :]
|
||||||
|
kv_state = torch.einsum("bhlf,bhld->bhfd", k_state.float(), v_state.float()).to(
|
||||||
|
k.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decode_kv_states[layer_idx] += kv_state
|
||||||
|
self.decode_k_states[layer_idx] += k_state
|
||||||
|
|
||||||
|
self.k_cache[layer_idx] = torch.cat([k_cache[:, :, 1:, :], k], dim=-2)
|
||||||
|
self.v_cache[layer_idx] = torch.cat([v_cache[:, :, 1:, :], v], dim=-2)
|
||||||
|
if layer_idx == 0:
|
||||||
|
self._seen_tokens += k.shape[-2]
|
||||||
|
self._seen_tokens_by_layer[layer_idx] += k.shape[-2]
|
||||||
|
return (
|
||||||
|
self.k_cache[layer_idx],
|
||||||
|
self.v_cache[layer_idx],
|
||||||
|
self.decode_kv_states[layer_idx],
|
||||||
|
self.decode_k_states[layer_idx],
|
||||||
|
)
|
||||||
@@ -0,0 +1,305 @@
|
|||||||
|
"""
|
||||||
|
LoLCATs attention combining sliding window and linear attentions
|
||||||
|
- Using the TK "terracing" arrangement
|
||||||
|
- Training over long sequences with fixed memory with recurrent view
|
||||||
|
- During attention transfer, use Flash Attention to compute softmax attention outputs
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
- We first compute (softmax) attention over sliding windows
|
||||||
|
- We then compute standard linear attention to "fill in" the earlier parts
|
||||||
|
- We combine to model the entire sequence
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
|
from ..model.rotary import apply_rotary_pos_emb
|
||||||
|
from .linear_attention import softmax_attention
|
||||||
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
|
||||||
|
LOG = logging.getLogger(
|
||||||
|
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_tk_long"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LolcatsTKWindowLongAttention(LolcatsTKWindowAttention):
|
||||||
|
"""
|
||||||
|
Lolcats attention combining sliding window and linear attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, remove_base_attn=True, **kwargs):
|
||||||
|
# keep self.base_attn for Flash Attention inference
|
||||||
|
super().__init__(remove_base_attn=True, **kwargs)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass with the option to compute attention weights multiple ways
|
||||||
|
if self.train_attention is True
|
||||||
|
-> Consistent with HuggingFace Transformers for easy use with their pretrained models
|
||||||
|
"""
|
||||||
|
b, l, _ = hidden_states.size()
|
||||||
|
if self.train_attention and self.base_inference:
|
||||||
|
with torch.no_grad():
|
||||||
|
# LOG.debug(hidden_states.shape)
|
||||||
|
_y_true = flash_attention_2(
|
||||||
|
self, # self.base_attn,
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=None,
|
||||||
|
output_attentions=False,
|
||||||
|
# output_hidden_states=False,
|
||||||
|
use_cache=False,
|
||||||
|
)[0]
|
||||||
|
# _y_true.shape is (batch_size, seq_len, num_heads, head_dim)
|
||||||
|
y_true = _y_true.reshape(b, l, -1).contiguous()
|
||||||
|
y_true = self.o_proj(y_true)
|
||||||
|
layer_io = (hidden_states, _y_true) # hack
|
||||||
|
# layer_io = (hidden_states.cpu(), _y_true.cpu()) # hack
|
||||||
|
return y_true, layer_io, None
|
||||||
|
|
||||||
|
q, k, v, kv_seq_len = self.process_qkv(
|
||||||
|
hidden_states, attention_mask, position_ids, past_key_value
|
||||||
|
)
|
||||||
|
f_q, f_k = self.feature_map_q(q), self.feature_map_k(k)
|
||||||
|
|
||||||
|
# attention_mask = None # For now this is always True
|
||||||
|
if past_key_value is None: # Regular training
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = 1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
y_pred, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
past_key_value.window_size = self.decode_window_size
|
||||||
|
if f_q.shape[2] == 1 and kv_seq_len > 1 and not self.training: # Generating
|
||||||
|
assert use_cache is True
|
||||||
|
_kv = past_key_value.update_for_decoding(
|
||||||
|
k, v, self.layer_idx, self.feature_map_k, dtype=q.dtype
|
||||||
|
)
|
||||||
|
k_cache, v_cache, f_kv_state, f_k_state = _kv
|
||||||
|
|
||||||
|
# Sliding window + linear attention decode
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
|
||||||
|
a_sm = torch.einsum("bhmd,bhnd->bhmn", q.float(), k_cache.float()) * (
|
||||||
|
k.shape[-1] ** -0.5
|
||||||
|
)
|
||||||
|
# a_sm = torch.softmax(a_sm, dim=-1)
|
||||||
|
a_sm_max = torch.amax(a_sm, dim=-1, keepdim=True)
|
||||||
|
a_sm = window_factors * torch.exp(a_sm - a_sm_max)
|
||||||
|
sum_sm = a_sm.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
y_pred = torch.einsum(
|
||||||
|
"bhmn,bhnd->bhmd", a_sm, v_cache.float()
|
||||||
|
) + linear_factors * torch.einsum(
|
||||||
|
"bhlf,bhfd->bhld", f_q.float(), f_kv_state.float()
|
||||||
|
)
|
||||||
|
sum_ln = (
|
||||||
|
linear_factors
|
||||||
|
* torch.einsum("bhlf,bhnf->bhl", f_q.float(), f_k_state.float())[
|
||||||
|
..., None
|
||||||
|
]
|
||||||
|
)
|
||||||
|
y_pred = (y_pred / (sum_sm + sum_ln)).to(q.dtype)
|
||||||
|
|
||||||
|
else: # Stateful training
|
||||||
|
if (
|
||||||
|
self.state_grad_enabled
|
||||||
|
and self.layer_idx == 0
|
||||||
|
and position_ids is not None
|
||||||
|
):
|
||||||
|
LOG.debug(
|
||||||
|
f"\n position_ids: [{position_ids[0, 0]}, {position_ids[0, -1]}]"
|
||||||
|
)
|
||||||
|
LOG.debug(
|
||||||
|
f"q.shape: {q.shape}, k.shape: {k.shape}, v.shape: {v.shape}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
kv_state = past_key_value.kv_states[self.layer_idx]
|
||||||
|
k_state = past_key_value.k_states[self.layer_idx]
|
||||||
|
except IndexError:
|
||||||
|
kv_state, k_state = None, None
|
||||||
|
window_factors = F.sigmoid(self.window_factors)
|
||||||
|
linear_factors = (
|
||||||
|
1 - window_factors if self.affine_attention_factors else 1
|
||||||
|
)
|
||||||
|
y_pred, a_pred = self.quadratic_attention(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
f_q,
|
||||||
|
f_k,
|
||||||
|
v,
|
||||||
|
window_factors,
|
||||||
|
linear_factors,
|
||||||
|
window_size=self.window_size,
|
||||||
|
kv_state=kv_state,
|
||||||
|
k_state=k_state,
|
||||||
|
)
|
||||||
|
# Save and update KV cache and states
|
||||||
|
# past_key_value.update(k, v.detach(), self.layer_idx,
|
||||||
|
# fmap_key_states=f_k.detach(),
|
||||||
|
# accumulate_in_fp32=True)
|
||||||
|
past_key_value.update(
|
||||||
|
k, v, self.layer_idx, fmap_key_states=f_k, accumulate_in_fp32=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Concatenate heads and apply output projection
|
||||||
|
_y_pred = y_pred.transpose(1, 2).contiguous()
|
||||||
|
y_pred = self.o_proj(_y_pred.view(b, l, self.hidden_size))
|
||||||
|
|
||||||
|
if self.train_attention:
|
||||||
|
with torch.no_grad():
|
||||||
|
a_true = softmax_attention(q, k, None, causal=True)[1]
|
||||||
|
attn_weights = (_y_pred, (a_pred, a_true))
|
||||||
|
else:
|
||||||
|
attn_weights = _y_pred # flash_attn outputs are shape (b, l, h, d)
|
||||||
|
return y_pred, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------
|
||||||
|
# Flash Attention 2
|
||||||
|
# -----------------
|
||||||
|
|
||||||
|
|
||||||
|
def flash_attention_2(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Wrapper for LlamaFlashAttention2
|
||||||
|
Copied and modified from HF Transformers v4.36 and v4.43 implementations
|
||||||
|
- (4.43) https://github.com/huggingface/transformers/blob/868d36d29ec132deeaaf8571b25b6a1b911d0145/src/transformers/models/llama/modeling_llama.py#L402
|
||||||
|
- (4.36) https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L456
|
||||||
|
"""
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
# Flash attention requires the input to have the shape
|
||||||
|
# batch_size x seq_length x head_dim x hidden_dim
|
||||||
|
# therefore we just need to keep the original shape
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
try: # As in Transformers v4.36
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
cos, sin = self.rotary_emb(key_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
except Exception: # As in Transformers v4.39
|
||||||
|
cos, sin = self.rotary_emb(key_states, position_ids)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
||||||
|
# to be able to avoid many of these transpose/reshape/view.
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
dropout_rate = self.attention_dropout if self.training else 0.0
|
||||||
|
|
||||||
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
|
# cast them back in the correct dtype just to be sure everything works as expected.
|
||||||
|
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
||||||
|
# in fp32. (LlamaRMSNorm handles it correctly)
|
||||||
|
|
||||||
|
input_dtype = query_states.dtype
|
||||||
|
if input_dtype == torch.float32:
|
||||||
|
if torch.is_autocast_enabled():
|
||||||
|
target_dtype = torch.get_autocast_gpu_dtype()
|
||||||
|
# Handle the case where the model is quantized
|
||||||
|
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
|
target_dtype = self.config._pre_quantization_dtype
|
||||||
|
else:
|
||||||
|
target_dtype = self.q_proj.weight.dtype
|
||||||
|
|
||||||
|
LOG.debug(
|
||||||
|
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||||
|
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||||
|
f" {target_dtype}."
|
||||||
|
)
|
||||||
|
|
||||||
|
query_states = query_states.to(target_dtype)
|
||||||
|
key_states = key_states.to(target_dtype)
|
||||||
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
|
if getattr(self, "_flash_attention_forward", False):
|
||||||
|
attn_output = self._flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=dropout_rate,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
attn_output = _flash_attention_forward(
|
||||||
|
query_states,
|
||||||
|
key_states,
|
||||||
|
value_states,
|
||||||
|
attention_mask,
|
||||||
|
q_len,
|
||||||
|
dropout=0, # dropout_rate,
|
||||||
|
sliding_window=getattr(self, "sliding_window", None),
|
||||||
|
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
return attn_output, past_key_value
|
||||||
34
src/axolotl/integrations/lolcats/linear_attention/utils.py
Normal file
34
src/axolotl/integrations/lolcats/linear_attention/utils.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""
|
||||||
|
Shared attention helpers
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||||
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
||||||
|
The hidden states go from:
|
||||||
|
(batch, num_key_value_heads, seqlen, head_dim) to
|
||||||
|
(batch, num_attention_heads, seqlen, head_dim)
|
||||||
|
"""
|
||||||
|
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||||
|
if n_rep == 1:
|
||||||
|
return hidden_states
|
||||||
|
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||||
|
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||||
|
)
|
||||||
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def mask_attention(
|
||||||
|
qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Apply attention mask (e.g., for padding)
|
||||||
|
"""
|
||||||
|
if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l)
|
||||||
|
return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
|
||||||
|
else:
|
||||||
|
return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
|
||||||
222
src/axolotl/integrations/lolcats/linearize_attention.py
Normal file
222
src/axolotl/integrations/lolcats/linearize_attention.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
Convert attention to linear attention
|
||||||
|
|
||||||
|
Adapted from: https://github.com/HazyResearch/lolcats/blob/main/src/model/convert_model.py
|
||||||
|
|
||||||
|
@misc{zhang2024lolcatslowranklinearizinglarge,
|
||||||
|
title={LoLCATs: On Low-Rank Linearizing of Large Language Models},
|
||||||
|
author={Michael Zhang and Simran Arora and Rahul Chalamala and Alan Wu and Benjamin Spector and Aaryan Singhal and Krithik Ramesh and Christopher Ré},
|
||||||
|
year={2024},
|
||||||
|
eLOG.info={2410.10254},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.LG},
|
||||||
|
url={https://arxiv.org/abs/2410.10254},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.integrations.lolcats.linearize_attention")
|
||||||
|
|
||||||
|
|
||||||
|
def convert_attention(
|
||||||
|
model: nn.Module,
|
||||||
|
attention_config: DictDefault,
|
||||||
|
train_attention: bool = False,
|
||||||
|
remove_base_attn: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Call to convert all attention layers
|
||||||
|
"""
|
||||||
|
softmax_attns = []
|
||||||
|
if "softmax_attentions" in attention_config:
|
||||||
|
softmax_attns = attention_config["softmax_attentions"]
|
||||||
|
if attention_config.attention_type != "softmax":
|
||||||
|
layers = traverse_layers(model)
|
||||||
|
for layer_idx, layer in enumerate(
|
||||||
|
tqdm(layers, desc="Converting attentions...")
|
||||||
|
):
|
||||||
|
if layer_idx not in softmax_attns:
|
||||||
|
layer.self_attn = convert_llama_attention(
|
||||||
|
layer,
|
||||||
|
attention_config,
|
||||||
|
layers,
|
||||||
|
train_attention,
|
||||||
|
remove_base_attn,
|
||||||
|
)
|
||||||
|
layer.self_attn.converted = True
|
||||||
|
else: # Freeze any preserved softmax attention layers
|
||||||
|
for p in layer.parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
f"-> attention_config.attention_type is {attention_config.attention_type}; not converting attentions"
|
||||||
|
)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def toggle_attention(llama_model: nn.Module, train: bool = False):
|
||||||
|
"""
|
||||||
|
Make attentions trainable if train is True
|
||||||
|
-> Set train_attention = False when finetuning
|
||||||
|
"""
|
||||||
|
for layer in traverse_layers(llama_model):
|
||||||
|
layer.self_attn.train_attention = train
|
||||||
|
return llama_model
|
||||||
|
|
||||||
|
|
||||||
|
def remove_base_attention(llama_model: nn.Module):
|
||||||
|
"""
|
||||||
|
Remove teacher attention after distillation (if we keep it)
|
||||||
|
"""
|
||||||
|
for layer in traverse_layers(llama_model):
|
||||||
|
if getattr(layer.self_attn, "base_attn", False):
|
||||||
|
del layer.self_attn.base_attn
|
||||||
|
return llama_model
|
||||||
|
|
||||||
|
|
||||||
|
def traverse_layers(model: nn.Module, verbose: bool = False):
|
||||||
|
"""
|
||||||
|
Return list of model layers
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
layers = model.model.layers
|
||||||
|
if verbose:
|
||||||
|
LOG.info("-> Loading from model.model.layers")
|
||||||
|
except AttributeError as e: # if base model
|
||||||
|
if verbose:
|
||||||
|
LOG.info(e)
|
||||||
|
try:
|
||||||
|
layers = model.layers
|
||||||
|
if verbose:
|
||||||
|
LOG.info("-> Loading from model.layers")
|
||||||
|
except AttributeError as e1: # If we make a PEFT model
|
||||||
|
if verbose:
|
||||||
|
LOG.info(e1)
|
||||||
|
layers = model.base_model.model.model.layers
|
||||||
|
if verbose:
|
||||||
|
LOG.info("-> Loading from model.base_model.model.model.layers")
|
||||||
|
return layers
|
||||||
|
|
||||||
|
|
||||||
|
def convert_llama_attention(
|
||||||
|
layer: nn.Module,
|
||||||
|
attention_config: DictDefault,
|
||||||
|
layers: list[nn.Module], # list of layers
|
||||||
|
train_attention: bool = False,
|
||||||
|
remove_base_attn: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Converts a single layer's attention layer as specified by attention_config
|
||||||
|
"""
|
||||||
|
return get_attention(**attention_config)(
|
||||||
|
base_attn=layer.self_attn,
|
||||||
|
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
||||||
|
max_layer_idx=len(layers) - 1,
|
||||||
|
train_attention=train_attention,
|
||||||
|
remove_base_attn=remove_base_attn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_attention(attention_type: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Get the linear attention class; either purely linear or linear with sliding window
|
||||||
|
-> 'linear' == 'lolcats_llama'
|
||||||
|
-> 'linear and sliding_window' == 'lolcats_llama_window_*'
|
||||||
|
"""
|
||||||
|
kwargs["attention_type"] = attention_type
|
||||||
|
|
||||||
|
if attention_type == "lolcats_llama":
|
||||||
|
from .linear_attention import LolcatsLinearAttention
|
||||||
|
|
||||||
|
return partial(LolcatsLinearAttention, **kwargs)
|
||||||
|
|
||||||
|
elif attention_type == "lolcats_llama_window_tk":
|
||||||
|
from .linear_attention import LolcatsTKWindowAttention
|
||||||
|
|
||||||
|
return partial(LolcatsTKWindowAttention, **kwargs)
|
||||||
|
|
||||||
|
elif attention_type == "lolcats_llama_window_sw":
|
||||||
|
from .linear_attention import LolcatsSlidingWindowAttention
|
||||||
|
|
||||||
|
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
|
elif attention_type == "lolcats_llama_window_sw_linear":
|
||||||
|
from .linear_attention.linear_window_attention_sw_linear import (
|
||||||
|
LolcatsLinearSlidingWindowAttention,
|
||||||
|
)
|
||||||
|
|
||||||
|
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
|
# Experimental chunked linear attentions below
|
||||||
|
elif attention_type == "lolcats_long_llama_window_tk":
|
||||||
|
from .linear_attention import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
|
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
|
elif attention_type == "lolcats_long_llama_window_sw":
|
||||||
|
from .linear_attention import LolcatsSlidingWindowLongAttention
|
||||||
|
|
||||||
|
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
|
# TK generation build (requires Thunderkittens)
|
||||||
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
|
from .linear_attention import LolcatsWindowAttentionTKGen
|
||||||
|
|
||||||
|
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
LOG.info(f"-> attention_type {attention_type} not handled... returning None")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
||||||
|
"""
|
||||||
|
Determine how we store past keys and values when generating
|
||||||
|
"""
|
||||||
|
if attention_type is None:
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
||||||
|
elif "lolcats_llama_window_tk_gen" in attention_type:
|
||||||
|
from .linear_attention import LinearAttentionTKWindowGenerationCache
|
||||||
|
|
||||||
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
|
elif "llama_window_tk" in attention_type:
|
||||||
|
from .linear_attention import LinearAttentionTKWindowCache
|
||||||
|
|
||||||
|
return LinearAttentionTKWindowCache()
|
||||||
|
|
||||||
|
elif "llama_window_sw" in attention_type:
|
||||||
|
from .linear_attention import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
|
elif "llama_window_sw_linear" in attention_type:
|
||||||
|
from .linear_attention import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
|
# TK generation build (requires Thunderkittens)
|
||||||
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
|
from .linear_attention.linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
|
elif "softmax" in attention_type:
|
||||||
|
return past_key_values
|
||||||
|
|
||||||
|
else:
|
||||||
|
from .linear_attention import LinearAttentionState
|
||||||
|
|
||||||
|
return LinearAttentionState()
|
||||||
336
src/axolotl/integrations/lolcats/model/feature_map.py
Normal file
336
src/axolotl/integrations/lolcats/model/feature_map.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""
|
||||||
|
Learnable linear attention feature map classes and functions
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map final activation for linear attention
|
||||||
|
"""
|
||||||
|
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map final activation for linear attention
|
||||||
|
"""
|
||||||
|
if name == "softmax_dim" and fullspace:
|
||||||
|
return SoftmaxDim(**kwargs)
|
||||||
|
elif name == "softmax_dim" and not fullspace:
|
||||||
|
return SoftmaxDimHalfspace(**kwargs)
|
||||||
|
elif name == "exp_dim" and fullspace:
|
||||||
|
return Exp(**kwargs)
|
||||||
|
elif name == "exp_dim" and not fullspace:
|
||||||
|
return ExpHalfspace(**kwargs)
|
||||||
|
elif name == "pos_elu":
|
||||||
|
return PosELU(**kwargs)
|
||||||
|
elif name == "relu":
|
||||||
|
return ReLU(**kwargs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def init_learned_kernel(name: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map MLP for linear attention
|
||||||
|
"""
|
||||||
|
if name == "untied_head_einsum":
|
||||||
|
return FeatureMapMLP(**kwargs)
|
||||||
|
elif name == "untied_head_adapter":
|
||||||
|
return FeatureMapAdapter(**kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMap(nn.Module):
|
||||||
|
"""
|
||||||
|
Final 'activation' of feature map. Can probably be combined with
|
||||||
|
`FeatureMapMLP` below
|
||||||
|
|
||||||
|
Full feature map is like f(xW + b)
|
||||||
|
-> This is the `f` part
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation_name: str,
|
||||||
|
head_dim_idx: int = -1,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mlp: Optional[nn.Module] = None,
|
||||||
|
fullspace: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim_idx = head_dim_idx
|
||||||
|
self.eps = eps
|
||||||
|
self.mlp = mlp if mlp is not None else nn.Identity()
|
||||||
|
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
||||||
|
|
||||||
|
def q_map(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Use for inference in case q and k feature maps differ
|
||||||
|
"""
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
def k_map(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Use for inference in case q and k feature maps differ
|
||||||
|
"""
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Feature map activations
|
||||||
|
# -----------------------
|
||||||
|
class FeatureMapAct(nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for feature map activations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eps: float = 1e-12):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PosELU(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return (1 + F.elu(x)).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
ReLU activation as in https://arxiv.org/abs/2103.13076
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return F.relu(x).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDim(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return torch.cat(
|
||||||
|
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
|
||||||
|
).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDimHalfspace(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Exp(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||||
|
x_min = torch.amin(x, dim=-1, keepdim=True)
|
||||||
|
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
|
||||||
|
min=self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpHalfspace(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||||
|
return torch.exp(x - x_max).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------
|
||||||
|
# Feature map MLPs
|
||||||
|
# ----------------
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapMLP(nn.Module):
|
||||||
|
"""
|
||||||
|
Learnable MLP in feature map.
|
||||||
|
|
||||||
|
Full feature map is like f(xW + b)
|
||||||
|
-> This is the `W` and (optional) `b` part
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int, # input dim
|
||||||
|
feature_dim: int, # output dim
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
skip_connection: bool = False,
|
||||||
|
bias: bool = False,
|
||||||
|
zero_init: bool = False,
|
||||||
|
normal_init: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.feature_dim = feature_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
self.skip_connection = skip_connection
|
||||||
|
self.bias = bias
|
||||||
|
self.zero_init = zero_init
|
||||||
|
self.normal_init = normal_init
|
||||||
|
self.init_weights_()
|
||||||
|
|
||||||
|
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
||||||
|
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
|
||||||
|
|
||||||
|
if self.normal_init:
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.normal_(self.layer)
|
||||||
|
|
||||||
|
if self.skip_connection:
|
||||||
|
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
|
||||||
|
assert self.head_dim == self.feature_dim, assertion_fail
|
||||||
|
|
||||||
|
def init_weights_(self):
|
||||||
|
"""
|
||||||
|
Initialize (W)eights and (b)iases
|
||||||
|
"""
|
||||||
|
self.layer = nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
(self.num_heads, self.head_dim, self.feature_dim),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.layer)
|
||||||
|
|
||||||
|
if self.bias:
|
||||||
|
self.bias = nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
(1, self.num_heads, 1, 1), # self.feature_dim),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.bias)
|
||||||
|
else:
|
||||||
|
self.bias = 0.0 # hack
|
||||||
|
|
||||||
|
def zero_init_with_skip_(self):
|
||||||
|
"""
|
||||||
|
Initialize weights to zero matrix if skip connection
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.zeros_(self.layer)
|
||||||
|
|
||||||
|
def zero_init_(self):
|
||||||
|
"""
|
||||||
|
Initialize weights to identity matrix if no skip connection
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(self.layer.shape[0]):
|
||||||
|
try:
|
||||||
|
nn.init.eye_(self.layer[i])
|
||||||
|
except RuntimeError:
|
||||||
|
with torch.no_grad():
|
||||||
|
dtype = self.layer[i].dtype
|
||||||
|
weight = torch.eye(
|
||||||
|
*self.layer[i].shape,
|
||||||
|
requires_grad=self.layer[i].requires_grad,
|
||||||
|
device=self.layer[i].device,
|
||||||
|
)
|
||||||
|
self.layer[i] = weight.to(dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
|
||||||
|
return x + _x if self.skip_connection else _x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapAdapter(FeatureMapMLP):
|
||||||
|
"""
|
||||||
|
Learnable Feature map with bottleneck adapter
|
||||||
|
as in https://arxiv.org/abs/1902.00751
|
||||||
|
|
||||||
|
We don't use but could be fun to try
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_dim: int, *args, **kwargs):
|
||||||
|
kwargs["skip_connection"] = True
|
||||||
|
kwargs["bias"] = True
|
||||||
|
kwargs["zero_init"] = True
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def init_weights_(self):
|
||||||
|
"""
|
||||||
|
Initialize (W)eights and (b)iases
|
||||||
|
"""
|
||||||
|
kwargs = {"dtype": self.dtype, "device": self.device}
|
||||||
|
self.layer0 = nn.Parameter(
|
||||||
|
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
||||||
|
)
|
||||||
|
self.layer1 = nn.Parameter(
|
||||||
|
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.layer0)
|
||||||
|
nn.init.kaiming_uniform_(self.layer1)
|
||||||
|
|
||||||
|
self.bias0 = nn.Parameter(
|
||||||
|
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
|
||||||
|
)
|
||||||
|
self.bias1 = nn.Parameter(
|
||||||
|
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.bias0)
|
||||||
|
nn.init.kaiming_uniform_(self.bias1)
|
||||||
|
|
||||||
|
def zero_init_with_skip_(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.zeros_(self.layer0)
|
||||||
|
nn.init.zeros_(self.layer1)
|
||||||
|
nn.init.zeros_(self.bias0)
|
||||||
|
nn.init.zeros_(self.bias1)
|
||||||
|
|
||||||
|
def zero_init_(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
-> Down-project, apply nonlinearity, up-project; add skip connection
|
||||||
|
"""
|
||||||
|
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
|
||||||
|
_x = F.relu(_x)
|
||||||
|
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
|
||||||
|
return x + _x if self.skip_connection else _x
|
||||||
204
src/axolotl/integrations/lolcats/model/rotary.py
Normal file
204
src/axolotl/integrations/lolcats/model/rotary.py
Normal file
@@ -0,0 +1,204 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Rotary embeddings. Same as usual for Transformer models.
|
||||||
|
|
||||||
|
Note these are modified from HF Transformers v4.36, from:
|
||||||
|
- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
|
||||||
|
- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotary_embeddings(
|
||||||
|
rope_scaling_type: Optional[str] = None,
|
||||||
|
head_dim: int = 128,
|
||||||
|
max_position_embeddings: int = 4096,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
rope_scaling_factor: float = 1.0,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
) -> nn.Module:
|
||||||
|
"""Return rotary embedding object"""
|
||||||
|
if rope_scaling_type is None:
|
||||||
|
return RotaryEmbedding(
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
base=rope_theta,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif rope_scaling_type == "linear":
|
||||||
|
return LinearScalingRotaryEmbedding(
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
scaling_factor=rope_scaling_factor,
|
||||||
|
base=rope_theta,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
elif rope_scaling_type == "dynamic":
|
||||||
|
return DynamicNTKScalingRotaryEmbedding(
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings=max_position_embeddings,
|
||||||
|
scaling_factor=rope_scaling_factor,
|
||||||
|
base=rope_theta,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||||
|
def rotate_half(x):
|
||||||
|
"""Rotates half the hidden dims of the input."""
|
||||||
|
x1 = x[..., : x.shape[-1] // 2]
|
||||||
|
x2 = x[..., x.shape[-1] // 2 :]
|
||||||
|
return torch.cat((-x2, x1), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||||
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
||||||
|
"""Applies Rotary Position Embedding to the query and key tensors."""
|
||||||
|
if position_ids is not None:
|
||||||
|
cos, sin = cos[position_ids], sin[position_ids]
|
||||||
|
cos = cos.unsqueeze(unsqueeze_dim)
|
||||||
|
sin = sin.unsqueeze(unsqueeze_dim)
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
||||||
|
class RotaryEmbedding(nn.Module):
|
||||||
|
"""Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864"""
|
||||||
|
|
||||||
|
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.dim = dim
|
||||||
|
self.max_position_embeddings = max_position_embeddings
|
||||||
|
self.base = base
|
||||||
|
inv_freq = 1.0 / (
|
||||||
|
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
# Build here to make `torch.jit.trace` work.
|
||||||
|
self._set_cos_sin_cache(
|
||||||
|
seq_len=max_position_embeddings,
|
||||||
|
device=self.inv_freq.device,
|
||||||
|
dtype=torch.get_default_dtype(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(
|
||||||
|
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
def forward(self, x, seq_len=None):
|
||||||
|
"""
|
||||||
|
Compute rotary embeddings
|
||||||
|
"""
|
||||||
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
if seq_len > self.max_seq_len_cached:
|
||||||
|
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
||||||
|
|
||||||
|
return (
|
||||||
|
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
||||||
|
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
base=10000,
|
||||||
|
device=None,
|
||||||
|
scaling_factor=1.0,
|
||||||
|
):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
t = torch.arange(
|
||||||
|
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
||||||
|
)
|
||||||
|
t = t / self.scaling_factor
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
||||||
|
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
||||||
|
"""RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
max_position_embeddings=2048,
|
||||||
|
base=10000,
|
||||||
|
device=None,
|
||||||
|
scaling_factor=1.0,
|
||||||
|
):
|
||||||
|
self.scaling_factor = scaling_factor
|
||||||
|
super().__init__(dim, max_position_embeddings, base, device)
|
||||||
|
|
||||||
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
|
self.max_seq_len_cached = seq_len
|
||||||
|
|
||||||
|
if seq_len > self.max_position_embeddings:
|
||||||
|
base = self.base * (
|
||||||
|
(self.scaling_factor * seq_len / self.max_position_embeddings)
|
||||||
|
- (self.scaling_factor - 1)
|
||||||
|
) ** (self.dim / (self.dim - 2))
|
||||||
|
inv_freq = 1.0 / (
|
||||||
|
base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
||||||
|
)
|
||||||
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
|
t = torch.arange(
|
||||||
|
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
freqs = torch.outer(t, self.inv_freq)
|
||||||
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
|
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
||||||
|
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
||||||
Reference in New Issue
Block a user