Compare commits
10 Commits
2d5f692fc0
...
feat/linea
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
13d458d0ae | ||
|
|
ebd406af1d | ||
|
|
caa49a9d7d | ||
|
|
c15ea6b956 | ||
|
|
578fa764c8 | ||
|
|
0e6efaa10c | ||
|
|
c4cb622590 | ||
|
|
0f82bd2d18 | ||
|
|
49746b184f | ||
|
|
9e1c4de13c |
@@ -49,12 +49,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
|||||||
for p in model.parameters():
|
for p in model.parameters():
|
||||||
p.requires_grad = False
|
p.requires_grad = False
|
||||||
|
|
||||||
# load config
|
|
||||||
base_config = load_model_config(cfg)
|
|
||||||
|
|
||||||
# convert to linear llama
|
# convert to linear llama
|
||||||
linear_llama_config = LinearLlamaConfig.from_llama(
|
linear_llama_config = LinearLlamaConfig.from_llama(
|
||||||
base_config, cfg.attention_config
|
model.config, cfg.attention_config
|
||||||
)
|
)
|
||||||
model = LinearLlamaForCausalLM.from_llama(
|
model = LinearLlamaForCausalLM.from_llama(
|
||||||
model, config=linear_llama_config, train_attention=True
|
model, config=linear_llama_config, train_attention=True
|
||||||
|
|||||||
@@ -4,7 +4,17 @@ https://github.com/HazyResearch/lolcats/
|
|||||||
|
|
||||||
### Usage
|
### Usage
|
||||||
|
|
||||||
TODO: Add instruction to install `causal_dot_product`.
|
Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd src/axolotl/integrations/lolcats/linear_llama/csrc
|
||||||
|
|
||||||
|
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
|
||||||
|
# nano setup.py
|
||||||
|
|
||||||
|
# Build the CUDA kernel
|
||||||
|
python setup.py install
|
||||||
|
```
|
||||||
|
|
||||||
Step 1:
|
Step 1:
|
||||||
|
|
||||||
@@ -15,7 +25,9 @@ plugins:
|
|||||||
linearize: true
|
linearize: true
|
||||||
```
|
```
|
||||||
|
|
||||||
Step 2: Remove the config above and finetune with lora with below possible targets.
|
Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
|
||||||
|
|
||||||
|
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
||||||
@@ -24,3 +36,9 @@ lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
|
|||||||
# to allow this config to work with lora
|
# to allow this config to work with lora
|
||||||
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
|
||||||
```
|
```
|
||||||
|
|
||||||
|
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
|
||||||
|
|
||||||
|
Step 3: Run inference on the finetuned model
|
||||||
|
|
||||||
|
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`
|
||||||
|
|||||||
@@ -44,4 +44,4 @@ class LinearAttentionArgs(BaseModel):
|
|||||||
|
|
||||||
attention_config: AttentionConfig
|
attention_config: AttentionConfig
|
||||||
|
|
||||||
linearize: bool
|
linearize: Optional[bool] = False
|
||||||
|
|||||||
@@ -1,21 +0,0 @@
|
|||||||
"""
|
|
||||||
Linear and linear attention + sliding window classes
|
|
||||||
"""
|
|
||||||
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
|
|
||||||
from .linear_window_attention_sw import (
|
|
||||||
LinearAttentionSlidingWindowCache,
|
|
||||||
LolcatsSlidingWindowAttention,
|
|
||||||
)
|
|
||||||
from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
|
|
||||||
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
|
||||||
from .linear_window_attention_tk import (
|
|
||||||
LinearAttentionTKWindowCache,
|
|
||||||
LolcatsTKWindowAttention,
|
|
||||||
)
|
|
||||||
from .linear_window_attention_tk_gen import (
|
|
||||||
LinearAttentionTKWindowGenerationCache,
|
|
||||||
LolcatsWindowAttentionTKGen,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Experimental chunk linear attentions
|
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""
|
|
||||||
Shared attention helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
|
|
||||||
The hidden states go from:
|
|
||||||
(batch, num_key_value_heads, seqlen, head_dim) to
|
|
||||||
(batch, num_attention_heads, seqlen, head_dim)
|
|
||||||
"""
|
|
||||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
|
||||||
if n_rep == 1:
|
|
||||||
return hidden_states
|
|
||||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
|
||||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
|
||||||
)
|
|
||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
|
||||||
|
|
||||||
|
|
||||||
def mask_attention(
|
|
||||||
qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Apply attention mask (e.g., for padding)
|
|
||||||
"""
|
|
||||||
if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l)
|
|
||||||
return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
|
|
||||||
else:
|
|
||||||
return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)
|
|
||||||
@@ -64,6 +64,13 @@ class LinearLlamaConfig(LlamaConfig):
|
|||||||
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
def __init__(self, attention_config: Optional[dict] = None, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
# Set auto_map
|
||||||
|
self.auto_map = {
|
||||||
|
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
|
||||||
|
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
|
||||||
|
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
|
||||||
|
}
|
||||||
|
|
||||||
# Set default attention config if none provided
|
# Set default attention config if none provided
|
||||||
self.attention_config = attention_config or {"attention_type": "softmax"}
|
self.attention_config = attention_config or {"attention_type": "softmax"}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from transformers.cache_utils import Cache
|
from transformers.cache_utils import Cache
|
||||||
|
|
||||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
@@ -15,9 +16,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
fast_causal_dot_product = None
|
fast_causal_dot_product = None
|
||||||
|
|
||||||
from ..model.feature_map import init_feature_map, init_learned_kernel
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
|
||||||
from .utils import repeat_kv
|
|
||||||
|
|
||||||
# -------------------
|
# -------------------
|
||||||
# Attention functions
|
# Attention functions
|
||||||
@@ -366,7 +365,7 @@ class LolcatsLinearAttention(nn.Module):
|
|||||||
..., None
|
..., None
|
||||||
] # b, 1, k_len, 1
|
] # b, 1, k_len, 1
|
||||||
else:
|
else:
|
||||||
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1
|
lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
|
||||||
k = k.masked_fill(~lin_attn_mask, 0)
|
k = k.masked_fill(~lin_attn_mask, 0)
|
||||||
|
|
||||||
if past_key_value is not None: # Initialize states
|
if past_key_value is not None: # Initialize states
|
||||||
@@ -523,3 +522,335 @@ class LinearAttentionState(Cache):
|
|||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Reordering cache not implemented for LinearAttentionState"
|
"Reordering cache not implemented for LinearAttentionState"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -------------------
|
||||||
|
# feature map functions
|
||||||
|
# -------------------
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map final activation for linear attention
|
||||||
|
"""
|
||||||
|
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map final activation for linear attention
|
||||||
|
"""
|
||||||
|
if name == "softmax_dim" and fullspace:
|
||||||
|
return SoftmaxDim(**kwargs)
|
||||||
|
elif name == "softmax_dim" and not fullspace:
|
||||||
|
return SoftmaxDimHalfspace(**kwargs)
|
||||||
|
elif name == "exp_dim" and fullspace:
|
||||||
|
return Exp(**kwargs)
|
||||||
|
elif name == "exp_dim" and not fullspace:
|
||||||
|
return ExpHalfspace(**kwargs)
|
||||||
|
elif name == "pos_elu":
|
||||||
|
return PosELU(**kwargs)
|
||||||
|
elif name == "relu":
|
||||||
|
return ReLU(**kwargs)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def init_learned_kernel(name: str, **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize feature map MLP for linear attention
|
||||||
|
"""
|
||||||
|
if name == "untied_head_einsum":
|
||||||
|
return FeatureMapMLP(**kwargs)
|
||||||
|
elif name == "untied_head_adapter":
|
||||||
|
return FeatureMapAdapter(**kwargs)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMap(nn.Module):
|
||||||
|
"""
|
||||||
|
Final 'activation' of feature map. Can probably be combined with
|
||||||
|
`FeatureMapMLP` below
|
||||||
|
|
||||||
|
Full feature map is like f(xW + b)
|
||||||
|
-> This is the `f` part
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
activation_name: str,
|
||||||
|
head_dim_idx: int = -1,
|
||||||
|
eps: float = 1e-12,
|
||||||
|
mlp: Optional[nn.Module] = None,
|
||||||
|
fullspace: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.head_dim_idx = head_dim_idx
|
||||||
|
self.eps = eps
|
||||||
|
self.mlp = mlp if mlp is not None else nn.Identity()
|
||||||
|
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
||||||
|
|
||||||
|
def q_map(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Use for inference in case q and k feature maps differ
|
||||||
|
"""
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
def k_map(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Use for inference in case q and k feature maps differ
|
||||||
|
"""
|
||||||
|
return self.forward(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Feature map activations
|
||||||
|
# -----------------------
|
||||||
|
class FeatureMapAct(nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for feature map activations
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, eps: float = 1e-12):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class PosELU(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return (1 + F.elu(x)).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class ReLU(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
ReLU activation as in https://arxiv.org/abs/2103.13076
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return F.relu(x).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDim(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return torch.cat(
|
||||||
|
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
|
||||||
|
).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class SoftmaxDimHalfspace(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Softmax activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
class Exp(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||||
|
x_min = torch.amin(x, dim=-1, keepdim=True)
|
||||||
|
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
|
||||||
|
min=self.eps
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExpHalfspace(FeatureMapAct):
|
||||||
|
"""
|
||||||
|
Exp activation as in https://arxiv.org/abs/2402.04347
|
||||||
|
"""
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||||
|
x_max = torch.amax(x, dim=-1, keepdim=True)
|
||||||
|
return torch.exp(x - x_max).clamp(min=self.eps)
|
||||||
|
|
||||||
|
|
||||||
|
# ----------------
|
||||||
|
# Feature map MLPs
|
||||||
|
# ----------------
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapMLP(nn.Module):
|
||||||
|
"""
|
||||||
|
Learnable MLP in feature map.
|
||||||
|
|
||||||
|
Full feature map is like f(xW + b)
|
||||||
|
-> This is the `W` and (optional) `b` part
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_dim: int, # input dim
|
||||||
|
feature_dim: int, # output dim
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
skip_connection: bool = False,
|
||||||
|
bias: bool = False,
|
||||||
|
zero_init: bool = False,
|
||||||
|
normal_init: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_dim = head_dim
|
||||||
|
self.feature_dim = feature_dim
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
self.skip_connection = skip_connection
|
||||||
|
self.bias = bias
|
||||||
|
self.zero_init = zero_init
|
||||||
|
self.normal_init = normal_init
|
||||||
|
self.init_weights_()
|
||||||
|
|
||||||
|
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
||||||
|
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
|
||||||
|
|
||||||
|
if self.normal_init:
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.normal_(self.layer)
|
||||||
|
|
||||||
|
if self.skip_connection:
|
||||||
|
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
|
||||||
|
assert self.head_dim == self.feature_dim, assertion_fail
|
||||||
|
|
||||||
|
def init_weights_(self):
|
||||||
|
"""
|
||||||
|
Initialize (W)eights and (b)iases
|
||||||
|
"""
|
||||||
|
self.layer = nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
(self.num_heads, self.head_dim, self.feature_dim),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.layer)
|
||||||
|
|
||||||
|
if self.bias:
|
||||||
|
self.bias = nn.Parameter(
|
||||||
|
torch.zeros(
|
||||||
|
(1, self.num_heads, 1, 1), # self.feature_dim),
|
||||||
|
dtype=self.dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.bias)
|
||||||
|
else:
|
||||||
|
self.bias = 0.0 # hack
|
||||||
|
|
||||||
|
def zero_init_with_skip_(self):
|
||||||
|
"""
|
||||||
|
Initialize weights to zero matrix if skip connection
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.zeros_(self.layer)
|
||||||
|
|
||||||
|
def zero_init_(self):
|
||||||
|
"""
|
||||||
|
Initialize weights to identity matrix if no skip connection
|
||||||
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(self.layer.shape[0]):
|
||||||
|
try:
|
||||||
|
nn.init.eye_(self.layer[i])
|
||||||
|
except RuntimeError:
|
||||||
|
with torch.no_grad():
|
||||||
|
dtype = self.layer[i].dtype
|
||||||
|
weight = torch.eye(
|
||||||
|
*self.layer[i].shape,
|
||||||
|
requires_grad=self.layer[i].requires_grad,
|
||||||
|
device=self.layer[i].device,
|
||||||
|
)
|
||||||
|
self.layer[i] = weight.to(dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
"""
|
||||||
|
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
|
||||||
|
return x + _x if self.skip_connection else _x
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureMapAdapter(FeatureMapMLP):
|
||||||
|
"""
|
||||||
|
Learnable Feature map with bottleneck adapter
|
||||||
|
as in https://arxiv.org/abs/1902.00751
|
||||||
|
|
||||||
|
We don't use but could be fun to try
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, hidden_dim: int, *args, **kwargs):
|
||||||
|
kwargs["skip_connection"] = True
|
||||||
|
kwargs["bias"] = True
|
||||||
|
kwargs["zero_init"] = True
|
||||||
|
self.hidden_dim = hidden_dim
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def init_weights_(self):
|
||||||
|
"""
|
||||||
|
Initialize (W)eights and (b)iases
|
||||||
|
"""
|
||||||
|
kwargs = {"dtype": self.dtype, "device": self.device}
|
||||||
|
self.layer0 = nn.Parameter(
|
||||||
|
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
||||||
|
)
|
||||||
|
self.layer1 = nn.Parameter(
|
||||||
|
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.layer0)
|
||||||
|
nn.init.kaiming_uniform_(self.layer1)
|
||||||
|
|
||||||
|
self.bias0 = nn.Parameter(
|
||||||
|
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
|
||||||
|
)
|
||||||
|
self.bias1 = nn.Parameter(
|
||||||
|
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
|
||||||
|
)
|
||||||
|
nn.init.kaiming_uniform_(self.bias0)
|
||||||
|
nn.init.kaiming_uniform_(self.bias1)
|
||||||
|
|
||||||
|
def zero_init_with_skip_(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
nn.init.zeros_(self.layer0)
|
||||||
|
nn.init.zeros_(self.layer1)
|
||||||
|
nn.init.zeros_(self.bias0)
|
||||||
|
nn.init.zeros_(self.bias1)
|
||||||
|
|
||||||
|
def zero_init_(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
||||||
|
-> Down-project, apply nonlinearity, up-project; add skip connection
|
||||||
|
"""
|
||||||
|
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
|
||||||
|
_x = F.relu(_x)
|
||||||
|
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
|
||||||
|
return x + _x if self.skip_connection else _x
|
||||||
@@ -23,7 +23,7 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
# Causal linear attention dot product CUDA kernel from fast-transformers
|
# Causal linear attention dot product CUDA kernel from fast-transformers
|
||||||
from .linear_attention import (
|
from .linear_attention import (
|
||||||
@@ -32,9 +32,7 @@ from .linear_attention import (
|
|||||||
causal_dot_product,
|
causal_dot_product,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
LOG = logging.getLogger(__name__)
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------------
|
# ----------------------
|
||||||
@@ -11,9 +11,7 @@ import torch.nn.functional as F
|
|||||||
from .linear_attention import LinearAttentionState
|
from .linear_attention import LinearAttentionState
|
||||||
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
LOG = logging.getLogger(
|
LOG = logging.getLogger(__name__)
|
||||||
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
from thunderkittens import hedgehog as tk_window_hedgehog_attention
|
||||||
@@ -22,7 +22,8 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
_flash_attention_forward = None # Transformers v4.36
|
_flash_attention_forward = None # Transformers v4.36
|
||||||
|
|
||||||
from ..model.rotary import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
from .linear_attention import softmax_attention
|
from .linear_attention import softmax_attention
|
||||||
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
|
||||||
@@ -1,336 +0,0 @@
|
|||||||
"""
|
|
||||||
Learnable linear attention feature map classes and functions
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
|
|
||||||
"""
|
|
||||||
Initialize feature map final activation for linear attention
|
|
||||||
"""
|
|
||||||
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
|
|
||||||
"""
|
|
||||||
Initialize feature map final activation for linear attention
|
|
||||||
"""
|
|
||||||
if name == "softmax_dim" and fullspace:
|
|
||||||
return SoftmaxDim(**kwargs)
|
|
||||||
elif name == "softmax_dim" and not fullspace:
|
|
||||||
return SoftmaxDimHalfspace(**kwargs)
|
|
||||||
elif name == "exp_dim" and fullspace:
|
|
||||||
return Exp(**kwargs)
|
|
||||||
elif name == "exp_dim" and not fullspace:
|
|
||||||
return ExpHalfspace(**kwargs)
|
|
||||||
elif name == "pos_elu":
|
|
||||||
return PosELU(**kwargs)
|
|
||||||
elif name == "relu":
|
|
||||||
return ReLU(**kwargs)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
def init_learned_kernel(name: str, **kwargs):
|
|
||||||
"""
|
|
||||||
Initialize feature map MLP for linear attention
|
|
||||||
"""
|
|
||||||
if name == "untied_head_einsum":
|
|
||||||
return FeatureMapMLP(**kwargs)
|
|
||||||
elif name == "untied_head_adapter":
|
|
||||||
return FeatureMapAdapter(**kwargs)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMap(nn.Module):
|
|
||||||
"""
|
|
||||||
Final 'activation' of feature map. Can probably be combined with
|
|
||||||
`FeatureMapMLP` below
|
|
||||||
|
|
||||||
Full feature map is like f(xW + b)
|
|
||||||
-> This is the `f` part
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
activation_name: str,
|
|
||||||
head_dim_idx: int = -1,
|
|
||||||
eps: float = 1e-12,
|
|
||||||
mlp: Optional[nn.Module] = None,
|
|
||||||
fullspace: bool = True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.head_dim_idx = head_dim_idx
|
|
||||||
self.eps = eps
|
|
||||||
self.mlp = mlp if mlp is not None else nn.Identity()
|
|
||||||
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
|
|
||||||
"""
|
|
||||||
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
|
|
||||||
"""
|
|
||||||
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
|
|
||||||
|
|
||||||
def q_map(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Use for inference in case q and k feature maps differ
|
|
||||||
"""
|
|
||||||
return self.forward(*args, **kwargs)
|
|
||||||
|
|
||||||
def k_map(self, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
Use for inference in case q and k feature maps differ
|
|
||||||
"""
|
|
||||||
return self.forward(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# -----------------------
|
|
||||||
# Feature map activations
|
|
||||||
# -----------------------
|
|
||||||
class FeatureMapAct(nn.Module):
|
|
||||||
"""
|
|
||||||
Base class for feature map activations
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, eps: float = 1e-12):
|
|
||||||
super().__init__()
|
|
||||||
self.eps = eps
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
"""
|
|
||||||
x.shape is (batch_size, n_heads, seq_len, head_dim)
|
|
||||||
"""
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class PosELU(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
1 + ELU activation as in https://arxiv.org/abs/2006.16236
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
return (1 + F.elu(x)).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class ReLU(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
ReLU activation as in https://arxiv.org/abs/2103.13076
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
return F.relu(x).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxDim(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
return torch.cat(
|
|
||||||
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
|
|
||||||
).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class SoftmaxDimHalfspace(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
Softmax activation as in https://arxiv.org/abs/2402.04347
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
return torch.softmax(x, dim=-1).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
class Exp(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
|
||||||
x_min = torch.amin(x, dim=-1, keepdim=True)
|
|
||||||
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
|
|
||||||
min=self.eps
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExpHalfspace(FeatureMapAct):
|
|
||||||
"""
|
|
||||||
Exp activation as in https://arxiv.org/abs/2402.04347
|
|
||||||
"""
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
|
||||||
x_max = torch.amax(x, dim=-1, keepdim=True)
|
|
||||||
return torch.exp(x - x_max).clamp(min=self.eps)
|
|
||||||
|
|
||||||
|
|
||||||
# ----------------
|
|
||||||
# Feature map MLPs
|
|
||||||
# ----------------
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMapMLP(nn.Module):
|
|
||||||
"""
|
|
||||||
Learnable MLP in feature map.
|
|
||||||
|
|
||||||
Full feature map is like f(xW + b)
|
|
||||||
-> This is the `W` and (optional) `b` part
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_heads: int,
|
|
||||||
head_dim: int, # input dim
|
|
||||||
feature_dim: int, # output dim
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
skip_connection: bool = False,
|
|
||||||
bias: bool = False,
|
|
||||||
zero_init: bool = False,
|
|
||||||
normal_init: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.num_heads = num_heads
|
|
||||||
self.head_dim = head_dim
|
|
||||||
self.feature_dim = feature_dim
|
|
||||||
self.dtype = dtype
|
|
||||||
self.device = device
|
|
||||||
self.skip_connection = skip_connection
|
|
||||||
self.bias = bias
|
|
||||||
self.zero_init = zero_init
|
|
||||||
self.normal_init = normal_init
|
|
||||||
self.init_weights_()
|
|
||||||
|
|
||||||
if self.zero_init: # Zero-out weights or set as identity post-initialization
|
|
||||||
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
|
|
||||||
|
|
||||||
if self.normal_init:
|
|
||||||
with torch.no_grad():
|
|
||||||
nn.init.normal_(self.layer)
|
|
||||||
|
|
||||||
if self.skip_connection:
|
|
||||||
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
|
|
||||||
assert self.head_dim == self.feature_dim, assertion_fail
|
|
||||||
|
|
||||||
def init_weights_(self):
|
|
||||||
"""
|
|
||||||
Initialize (W)eights and (b)iases
|
|
||||||
"""
|
|
||||||
self.layer = nn.Parameter(
|
|
||||||
torch.zeros(
|
|
||||||
(self.num_heads, self.head_dim, self.feature_dim),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
nn.init.kaiming_uniform_(self.layer)
|
|
||||||
|
|
||||||
if self.bias:
|
|
||||||
self.bias = nn.Parameter(
|
|
||||||
torch.zeros(
|
|
||||||
(1, self.num_heads, 1, 1), # self.feature_dim),
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
nn.init.kaiming_uniform_(self.bias)
|
|
||||||
else:
|
|
||||||
self.bias = 0.0 # hack
|
|
||||||
|
|
||||||
def zero_init_with_skip_(self):
|
|
||||||
"""
|
|
||||||
Initialize weights to zero matrix if skip connection
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
nn.init.zeros_(self.layer)
|
|
||||||
|
|
||||||
def zero_init_(self):
|
|
||||||
"""
|
|
||||||
Initialize weights to identity matrix if no skip connection
|
|
||||||
"""
|
|
||||||
with torch.no_grad():
|
|
||||||
for i in range(self.layer.shape[0]):
|
|
||||||
try:
|
|
||||||
nn.init.eye_(self.layer[i])
|
|
||||||
except RuntimeError:
|
|
||||||
with torch.no_grad():
|
|
||||||
dtype = self.layer[i].dtype
|
|
||||||
weight = torch.eye(
|
|
||||||
*self.layer[i].shape,
|
|
||||||
requires_grad=self.layer[i].requires_grad,
|
|
||||||
device=self.layer[i].device,
|
|
||||||
)
|
|
||||||
self.layer[i] = weight.to(dtype=dtype)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
|
||||||
"""
|
|
||||||
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
|
|
||||||
return x + _x if self.skip_connection else _x
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureMapAdapter(FeatureMapMLP):
|
|
||||||
"""
|
|
||||||
Learnable Feature map with bottleneck adapter
|
|
||||||
as in https://arxiv.org/abs/1902.00751
|
|
||||||
|
|
||||||
We don't use but could be fun to try
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, hidden_dim: int, *args, **kwargs):
|
|
||||||
kwargs["skip_connection"] = True
|
|
||||||
kwargs["bias"] = True
|
|
||||||
kwargs["zero_init"] = True
|
|
||||||
self.hidden_dim = hidden_dim
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def init_weights_(self):
|
|
||||||
"""
|
|
||||||
Initialize (W)eights and (b)iases
|
|
||||||
"""
|
|
||||||
kwargs = {"dtype": self.dtype, "device": self.device}
|
|
||||||
self.layer0 = nn.Parameter(
|
|
||||||
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
|
|
||||||
)
|
|
||||||
self.layer1 = nn.Parameter(
|
|
||||||
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
|
|
||||||
)
|
|
||||||
nn.init.kaiming_uniform_(self.layer0)
|
|
||||||
nn.init.kaiming_uniform_(self.layer1)
|
|
||||||
|
|
||||||
self.bias0 = nn.Parameter(
|
|
||||||
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
|
|
||||||
)
|
|
||||||
self.bias1 = nn.Parameter(
|
|
||||||
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
|
|
||||||
)
|
|
||||||
nn.init.kaiming_uniform_(self.bias0)
|
|
||||||
nn.init.kaiming_uniform_(self.bias1)
|
|
||||||
|
|
||||||
def zero_init_with_skip_(self):
|
|
||||||
with torch.no_grad():
|
|
||||||
nn.init.zeros_(self.layer0)
|
|
||||||
nn.init.zeros_(self.layer1)
|
|
||||||
nn.init.zeros_(self.bias0)
|
|
||||||
nn.init.zeros_(self.bias1)
|
|
||||||
|
|
||||||
def zero_init_(self):
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
"""
|
|
||||||
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
|
|
||||||
-> Down-project, apply nonlinearity, up-project; add skip connection
|
|
||||||
"""
|
|
||||||
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
|
|
||||||
_x = F.relu(_x)
|
|
||||||
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
|
|
||||||
return x + _x if self.skip_connection else _x
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
# coding=utf-8
|
|
||||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
|
||||||
#
|
|
||||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
||||||
# and OPT implementations in this library. It has been modified from its
|
|
||||||
# original forms to accommodate minor architectural differences compared
|
|
||||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""
|
|
||||||
Rotary embeddings. Same as usual for Transformer models.
|
|
||||||
|
|
||||||
Note these are modified from HF Transformers v4.36, from:
|
|
||||||
- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
|
|
||||||
- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
|
|
||||||
"""
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
def get_rotary_embeddings(
|
|
||||||
rope_scaling_type: Optional[str] = None,
|
|
||||||
head_dim: int = 128,
|
|
||||||
max_position_embeddings: int = 4096,
|
|
||||||
rope_theta: float = 10000.0,
|
|
||||||
rope_scaling_factor: float = 1.0,
|
|
||||||
device: Optional[torch.device] = None,
|
|
||||||
) -> nn.Module:
|
|
||||||
"""Return rotary embedding object"""
|
|
||||||
if rope_scaling_type is None:
|
|
||||||
return RotaryEmbedding(
|
|
||||||
head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
base=rope_theta,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif rope_scaling_type == "linear":
|
|
||||||
return LinearScalingRotaryEmbedding(
|
|
||||||
head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
scaling_factor=rope_scaling_factor,
|
|
||||||
base=rope_theta,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
elif rope_scaling_type == "dynamic":
|
|
||||||
return DynamicNTKScalingRotaryEmbedding(
|
|
||||||
head_dim,
|
|
||||||
max_position_embeddings=max_position_embeddings,
|
|
||||||
scaling_factor=rope_scaling_factor,
|
|
||||||
base=rope_theta,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(
|
|
||||||
f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
|
|
||||||
"""Applies Rotary Position Embedding to the query and key tensors."""
|
|
||||||
if position_ids is not None:
|
|
||||||
cos, sin = cos[position_ids], sin[position_ids]
|
|
||||||
cos = cos.unsqueeze(unsqueeze_dim)
|
|
||||||
sin = sin.unsqueeze(unsqueeze_dim)
|
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
|
|
||||||
class RotaryEmbedding(nn.Module):
|
|
||||||
"""Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864"""
|
|
||||||
|
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dim = dim
|
|
||||||
self.max_position_embeddings = max_position_embeddings
|
|
||||||
self.base = base
|
|
||||||
inv_freq = 1.0 / (
|
|
||||||
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
|
||||||
self._set_cos_sin_cache(
|
|
||||||
seq_len=max_position_embeddings,
|
|
||||||
device=self.inv_freq.device,
|
|
||||||
dtype=torch.get_default_dtype(),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(
|
|
||||||
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
freqs = torch.outer(t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
|
||||||
"""
|
|
||||||
Compute rotary embeddings
|
|
||||||
"""
|
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
|
||||||
|
|
||||||
return (
|
|
||||||
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
|
||||||
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
|
||||||
class LinearScalingRotaryEmbedding(RotaryEmbedding):
|
|
||||||
"""RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
base=10000,
|
|
||||||
device=None,
|
|
||||||
scaling_factor=1.0,
|
|
||||||
):
|
|
||||||
self.scaling_factor = scaling_factor
|
|
||||||
super().__init__(dim, max_position_embeddings, base, device)
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(
|
|
||||||
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
|
||||||
)
|
|
||||||
t = t / self.scaling_factor
|
|
||||||
|
|
||||||
freqs = torch.outer(t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers/models/llama/modeling_llama.py at v4.36
|
|
||||||
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
|
|
||||||
"""RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
base=10000,
|
|
||||||
device=None,
|
|
||||||
scaling_factor=1.0,
|
|
||||||
):
|
|
||||||
self.scaling_factor = scaling_factor
|
|
||||||
super().__init__(dim, max_position_embeddings, base, device)
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
|
|
||||||
if seq_len > self.max_position_embeddings:
|
|
||||||
base = self.base * (
|
|
||||||
(self.scaling_factor * seq_len / self.max_position_embeddings)
|
|
||||||
- (self.scaling_factor - 1)
|
|
||||||
) ** (self.dim / (self.dim - 2))
|
|
||||||
inv_freq = 1.0 / (
|
|
||||||
base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
|
||||||
)
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
|
|
||||||
t = torch.arange(
|
|
||||||
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
freqs = torch.outer(t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
|
||||||
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
|
||||||
@@ -11,7 +11,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
|
|||||||
LlamaRotaryEmbedding,
|
LlamaRotaryEmbedding,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .attention import LolcatsLinearAttention
|
|
||||||
from .configuration_linear_llama import LinearLlamaConfig
|
from .configuration_linear_llama import LinearLlamaConfig
|
||||||
|
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
|
|||||||
|
|
||||||
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
def __init__(self, config: LinearLlamaConfig, layer_idx: int):
|
||||||
super().__init__(config, layer_idx)
|
super().__init__(config, layer_idx)
|
||||||
|
|
||||||
# Replace the attention layer with our custom attention
|
# Replace the attention layer with our custom attention
|
||||||
self.self_attn = LolcatsLinearAttention(
|
self.self_attn = convert_llama_attention(
|
||||||
base_attn=self.self_attn, # type: ignore
|
layer=self, attention_config=config.attention_config
|
||||||
layer_idx=layer_idx,
|
|
||||||
**config.attention_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -110,18 +108,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError("Missing config")
|
raise ValueError("Missing config")
|
||||||
|
|
||||||
# initialize the model with prior weights
|
# initialize a new model with config
|
||||||
new_model = cls(config=config)
|
new_model = cls(config=config)
|
||||||
|
|
||||||
del new_model.model # remove the default model
|
# remove the default model and lm_head
|
||||||
|
del new_model.model
|
||||||
|
del new_model.lm_head
|
||||||
|
|
||||||
|
# load converted model, lm_head, and vocab_size from llama model
|
||||||
new_model.model = convert_attention(
|
new_model.model = convert_attention(
|
||||||
model.model,
|
model.model,
|
||||||
attention_config=config.attention_config,
|
attention_config=config.attention_config,
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
|
new_model.lm_head = model.lm_head
|
||||||
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
|
new_model.vocab_size = model.vocab_size
|
||||||
|
|
||||||
return new_model
|
return new_model
|
||||||
|
|
||||||
@@ -227,7 +229,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
|
|||||||
def convert_llama_attention(
|
def convert_llama_attention(
|
||||||
layer: nn.Module,
|
layer: nn.Module,
|
||||||
attention_config: dict,
|
attention_config: dict,
|
||||||
layers: list[nn.Module], # list of layers
|
layers: Optional[list[nn.Module]] = None, # list of layers
|
||||||
train_attention: bool = False,
|
train_attention: bool = False,
|
||||||
remove_base_attn: bool = True,
|
remove_base_attn: bool = True,
|
||||||
):
|
):
|
||||||
@@ -237,7 +239,7 @@ def convert_llama_attention(
|
|||||||
return get_attention(**attention_config)(
|
return get_attention(**attention_config)(
|
||||||
base_attn=layer.self_attn,
|
base_attn=layer.self_attn,
|
||||||
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
|
||||||
max_layer_idx=len(layers) - 1,
|
max_layer_idx=len(layers) - 1 if layers else None,
|
||||||
train_attention=train_attention,
|
train_attention=train_attention,
|
||||||
remove_base_attn=remove_base_attn,
|
remove_base_attn=remove_base_attn,
|
||||||
)
|
)
|
||||||
@@ -252,39 +254,41 @@ def get_attention(attention_type: str, **kwargs):
|
|||||||
kwargs["attention_type"] = attention_type
|
kwargs["attention_type"] = attention_type
|
||||||
|
|
||||||
if attention_type == "lolcats_llama":
|
if attention_type == "lolcats_llama":
|
||||||
from .attention import LolcatsLinearAttention
|
from .linear_attention import LolcatsLinearAttention
|
||||||
|
|
||||||
return partial(LolcatsLinearAttention, **kwargs)
|
return partial(LolcatsLinearAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_tk":
|
elif attention_type == "lolcats_llama_window_tk":
|
||||||
from .attention import LolcatsTKWindowAttention
|
from .linear_window_attention_tk import LolcatsTKWindowAttention
|
||||||
|
|
||||||
return partial(LolcatsTKWindowAttention, **kwargs)
|
return partial(LolcatsTKWindowAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_sw":
|
elif attention_type == "lolcats_llama_window_sw":
|
||||||
from .attention import LolcatsSlidingWindowAttention
|
from .linear_window_attention_sw import LolcatsSlidingWindowAttention
|
||||||
|
|
||||||
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
return partial(LolcatsSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_llama_window_sw_linear":
|
elif attention_type == "lolcats_llama_window_sw_linear":
|
||||||
from .attention import LolcatsLinearSlidingWindowAttention
|
from .linear_window_attention_sw_linear import (
|
||||||
|
LolcatsLinearSlidingWindowAttention,
|
||||||
|
)
|
||||||
|
|
||||||
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
|
||||||
|
|
||||||
# Experimental chunked linear attentions below
|
# Experimental chunked linear attentions below
|
||||||
elif attention_type == "lolcats_long_llama_window_tk":
|
elif attention_type == "lolcats_long_llama_window_tk":
|
||||||
from .attention import LolcatsTKWindowLongAttention
|
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
|
||||||
|
|
||||||
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
return partial(LolcatsTKWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
elif attention_type == "lolcats_long_llama_window_sw":
|
elif attention_type == "lolcats_long_llama_window_sw":
|
||||||
from .attention import LolcatsSlidingWindowLongAttention
|
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
|
||||||
|
|
||||||
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
return partial(LolcatsSlidingWindowLongAttention, **kwargs)
|
||||||
|
|
||||||
# TK generation build (requires Thunderkittens)
|
# TK generation build (requires Thunderkittens)
|
||||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
from .attention import LolcatsWindowAttentionTKGen
|
from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
|
||||||
|
|
||||||
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
return partial(LolcatsWindowAttentionTKGen, **kwargs)
|
||||||
|
|
||||||
@@ -302,28 +306,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
|||||||
|
|
||||||
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
|
||||||
elif "lolcats_llama_window_tk_gen" in attention_type:
|
elif "lolcats_llama_window_tk_gen" in attention_type:
|
||||||
from .attention import LinearAttentionTKWindowGenerationCache
|
from .linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
)
|
||||||
|
|
||||||
return LinearAttentionTKWindowGenerationCache()
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
elif "llama_window_tk" in attention_type:
|
elif "llama_window_tk" in attention_type:
|
||||||
from .attention import LinearAttentionTKWindowCache
|
from .linear_window_attention_tk import LinearAttentionTKWindowCache
|
||||||
|
|
||||||
return LinearAttentionTKWindowCache()
|
return LinearAttentionTKWindowCache()
|
||||||
|
|
||||||
elif "llama_window_sw" in attention_type:
|
elif "llama_window_sw" in attention_type:
|
||||||
from .attention import LinearAttentionSlidingWindowCache
|
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
return LinearAttentionSlidingWindowCache()
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
elif "llama_window_sw_linear" in attention_type:
|
elif "llama_window_sw_linear" in attention_type:
|
||||||
from .attention import LinearAttentionSlidingWindowCache
|
from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
|
||||||
|
|
||||||
return LinearAttentionSlidingWindowCache()
|
return LinearAttentionSlidingWindowCache()
|
||||||
|
|
||||||
# TK generation build (requires Thunderkittens)
|
# TK generation build (requires Thunderkittens)
|
||||||
elif attention_type == "lolcats_llama_window_tk_gen":
|
elif attention_type == "lolcats_llama_window_tk_gen":
|
||||||
from .attention import LinearAttentionTKWindowGenerationCache
|
from .linear_window_attention_tk_gen import (
|
||||||
|
LinearAttentionTKWindowGenerationCache,
|
||||||
|
)
|
||||||
|
|
||||||
return LinearAttentionTKWindowGenerationCache()
|
return LinearAttentionTKWindowGenerationCache()
|
||||||
|
|
||||||
@@ -331,7 +339,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
|
|||||||
return past_key_values
|
return past_key_values
|
||||||
|
|
||||||
else:
|
else:
|
||||||
from .attention import LinearAttentionState
|
from .linear_attention import LinearAttentionState
|
||||||
|
|
||||||
return LinearAttentionState()
|
return LinearAttentionState()
|
||||||
|
|
||||||
@@ -346,3 +354,8 @@ def register_linear_llama():
|
|||||||
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
AutoConfig.register("linear_llama", LinearLlamaConfig)
|
||||||
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
|
||||||
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
|
||||||
|
|
||||||
|
# registering for auto classes to save files
|
||||||
|
LinearLlamaConfig.register_for_auto_class("AutoConfig")
|
||||||
|
LinearLlamaModel.register_for_auto_class("AutoModel")
|
||||||
|
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")
|
||||||
|
|||||||
Reference in New Issue
Block a user