Compare commits

..

10 Commits

Author SHA1 Message Date
NanoCode012
13d458d0ae feat: update readme with inference instructions 2025-02-06 21:29:36 +07:00
NanoCode012
ebd406af1d fix: lin_attn_mask in wrong dtype 2025-02-06 15:25:33 +07:00
NanoCode012
caa49a9d7d fix: use existing model config 2025-02-06 00:12:14 +07:00
NanoCode012
c15ea6b956 fix: load vocab_size 2025-02-05 23:46:59 +07:00
NanoCode012
578fa764c8 chore: moved feature map into linear attention 2025-02-05 19:40:11 +07:00
NanoCode012
0e6efaa10c fix: manually set auto-map 2025-02-05 19:35:15 +07:00
NanoCode012
c4cb622590 fix: remove redundant files 2025-02-05 19:34:06 +07:00
NanoCode012
0f82bd2d18 chore: improve instruction and made linearize optional 2025-02-05 19:33:15 +07:00
NanoCode012
49746b184f chore: flatten directory structure and register to autoclass to save 2025-02-05 19:17:57 +07:00
NanoCode012
9e1c4de13c fix: assign linear head instead of loading state dict 2025-02-05 18:24:31 +07:00
17 changed files with 407 additions and 639 deletions

View File

@@ -49,12 +49,9 @@ def do_linearize(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
for p in model.parameters(): for p in model.parameters():
p.requires_grad = False p.requires_grad = False
# load config
base_config = load_model_config(cfg)
# convert to linear llama # convert to linear llama
linear_llama_config = LinearLlamaConfig.from_llama( linear_llama_config = LinearLlamaConfig.from_llama(
base_config, cfg.attention_config model.config, cfg.attention_config
) )
model = LinearLlamaForCausalLM.from_llama( model = LinearLlamaForCausalLM.from_llama(
model, config=linear_llama_config, train_attention=True model, config=linear_llama_config, train_attention=True

View File

@@ -4,7 +4,17 @@ https://github.com/HazyResearch/lolcats/
### Usage ### Usage
TODO: Add instruction to install `causal_dot_product`. Install `causal_dot_product` CUDA kernel (check the README in the `csrc` directory):
```bash
cd src/axolotl/integrations/lolcats/linear_llama/csrc
# Edit `setup.py` to point to the correct CUDA capabilities L40-44
# nano setup.py
# Build the CUDA kernel
python setup.py install
```
Step 1: Step 1:
@@ -15,7 +25,9 @@ plugins:
linearize: true linearize: true
``` ```
Step 2: Remove the config above and finetune with lora with below possible targets. Run axolotl: `python -m axolotl.cli.convert_linear_attention config.yaml` TODO: change path CLI
Step 2: Remove the config `linearize: true` and finetune with lora with below possible targets.
```yaml ```yaml
lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"] lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
@@ -24,3 +36,9 @@ lora_target_modules: ["q_proj", "k_proj", "v_proj", "o_proj"]
# to allow this config to work with lora # to allow this config to work with lora
# unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*'] # unfrozen_parameters: ['.*feature_map_q.mlp.layer.*', '.*feature_map_k.mlp.layer.*', '.*window_factors.*']
``` ```
`axolotl train config.yaml --base-model={output_dir}/distilled --trust-remote-code --learning-rate=0.0001 # --wandb-project="..."`
Step 3: Run inference on the finetuned model
`axolotl inference config.yaml --lora-model-dir="{output_dir}" --trust-remote-code # --prompter="AlpacaPrompter"`

View File

@@ -44,4 +44,4 @@ class LinearAttentionArgs(BaseModel):
attention_config: AttentionConfig attention_config: AttentionConfig
linearize: bool linearize: Optional[bool] = False

View File

@@ -1,21 +0,0 @@
"""
Linear and linear attention + sliding window classes
"""
from .linear_attention import LinearAttentionState, LolcatsLinearAttention
from .linear_window_attention_sw import (
LinearAttentionSlidingWindowCache,
LolcatsSlidingWindowAttention,
)
from .linear_window_attention_sw_linear import LolcatsLinearSlidingWindowAttention
from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
from .linear_window_attention_tk import (
LinearAttentionTKWindowCache,
LolcatsTKWindowAttention,
)
from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
LolcatsWindowAttentionTKGen,
)
# Experimental chunk linear attentions
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention

View File

@@ -1,34 +0,0 @@
"""
Shared attention helpers
"""
import torch
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).
The hidden states go from:
(batch, num_key_value_heads, seqlen, head_dim) to
(batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def mask_attention(
qk_dot: torch.Tensor, attn_mask: torch.Tensor, mask_value: float = -10000
) -> torch.Tensor:
"""
Apply attention mask (e.g., for padding)
"""
if len(attn_mask.shape) == 4: # attn_mask either (b, h, l, d) or (b, l)
return qk_dot.masked_fill(~attn_mask.bool(), mask_value)
else:
return qk_dot.masked_fill(~attn_mask[:, None, None, :].bool(), mask_value)

View File

@@ -64,6 +64,13 @@ class LinearLlamaConfig(LlamaConfig):
def __init__(self, attention_config: Optional[dict] = None, **kwargs): def __init__(self, attention_config: Optional[dict] = None, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
# Set auto_map
self.auto_map = {
"AutoConfig": "configuration_linear_llama.LinearLlamaConfig",
"AutoModel": "modeling_linear_llama.LinearLlamaModel",
"AutoModelForCausalLM": "modeling_linear_llama.LinearLlamaForCausalLM",
}
# Set default attention config if none provided # Set default attention config if none provided
self.attention_config = attention_config or {"attention_type": "softmax"} self.attention_config = attention_config or {"attention_type": "softmax"}

View File

@@ -7,6 +7,7 @@ from typing import Any, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from transformers.cache_utils import Cache from transformers.cache_utils import Cache
# Causal linear attention dot product CUDA kernel from fast-transformers # Causal linear attention dot product CUDA kernel from fast-transformers
@@ -15,9 +16,7 @@ try:
except ImportError: except ImportError:
fast_causal_dot_product = None fast_causal_dot_product = None
from ..model.feature_map import init_feature_map, init_learned_kernel from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from ..model.rotary import apply_rotary_pos_emb
from .utils import repeat_kv
# ------------------- # -------------------
# Attention functions # Attention functions
@@ -366,7 +365,7 @@ class LolcatsLinearAttention(nn.Module):
..., None ..., None
] # b, 1, k_len, 1 ] # b, 1, k_len, 1
else: else:
lin_attn_mask = attention_mask[:, None, :, None] # b, 1, k_len, 1 lin_attn_mask = attention_mask.bool()[:, None, :, None] # b, 1, k_len, 1
k = k.masked_fill(~lin_attn_mask, 0) k = k.masked_fill(~lin_attn_mask, 0)
if past_key_value is not None: # Initialize states if past_key_value is not None: # Initialize states
@@ -523,3 +522,335 @@ class LinearAttentionState(Cache):
raise NotImplementedError( raise NotImplementedError(
"Reordering cache not implemented for LinearAttentionState" "Reordering cache not implemented for LinearAttentionState"
) )
# -------------------
# feature map functions
# -------------------
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
if name == "softmax_dim" and fullspace:
return SoftmaxDim(**kwargs)
elif name == "softmax_dim" and not fullspace:
return SoftmaxDimHalfspace(**kwargs)
elif name == "exp_dim" and fullspace:
return Exp(**kwargs)
elif name == "exp_dim" and not fullspace:
return ExpHalfspace(**kwargs)
elif name == "pos_elu":
return PosELU(**kwargs)
elif name == "relu":
return ReLU(**kwargs)
else:
raise NotImplementedError
def init_learned_kernel(name: str, **kwargs):
"""
Initialize feature map MLP for linear attention
"""
if name == "untied_head_einsum":
return FeatureMapMLP(**kwargs)
elif name == "untied_head_adapter":
return FeatureMapAdapter(**kwargs)
else:
raise NotImplementedError
class FeatureMap(nn.Module):
"""
Final 'activation' of feature map. Can probably be combined with
`FeatureMapMLP` below
Full feature map is like f(xW + b)
-> This is the `f` part
"""
def __init__(
self,
activation_name: str,
head_dim_idx: int = -1,
eps: float = 1e-12,
mlp: Optional[nn.Module] = None,
fullspace: bool = True,
):
super().__init__()
self.head_dim_idx = head_dim_idx
self.eps = eps
self.mlp = mlp if mlp is not None else nn.Identity()
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
"""
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
def q_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
def k_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
"""
Base class for feature map activations
"""
def __init__(self, eps: float = 1e-12):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return x
class PosELU(FeatureMapAct):
"""
1 + ELU activation as in https://arxiv.org/abs/2006.16236
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return (1 + F.elu(x)).clamp(min=self.eps)
class ReLU(FeatureMapAct):
"""
ReLU activation as in https://arxiv.org/abs/2103.13076
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return F.relu(x).clamp(min=self.eps)
class SoftmaxDim(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.cat(
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
).clamp(min=self.eps)
class SoftmaxDimHalfspace(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.softmax(x, dim=-1).clamp(min=self.eps)
class Exp(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
x_min = torch.amin(x, dim=-1, keepdim=True)
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
min=self.eps
)
class ExpHalfspace(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
return torch.exp(x - x_max).clamp(min=self.eps)
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
"""
Learnable MLP in feature map.
Full feature map is like f(xW + b)
-> This is the `W` and (optional) `b` part
"""
def __init__(
self,
num_heads: int,
head_dim: int, # input dim
feature_dim: int, # output dim
dtype: torch.dtype,
device: torch.device,
skip_connection: bool = False,
bias: bool = False,
zero_init: bool = False,
normal_init: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.feature_dim = feature_dim
self.dtype = dtype
self.device = device
self.skip_connection = skip_connection
self.bias = bias
self.zero_init = zero_init
self.normal_init = normal_init
self.init_weights_()
if self.zero_init: # Zero-out weights or set as identity post-initialization
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
if self.normal_init:
with torch.no_grad():
nn.init.normal_(self.layer)
if self.skip_connection:
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
assert self.head_dim == self.feature_dim, assertion_fail
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
self.layer = nn.Parameter(
torch.zeros(
(self.num_heads, self.head_dim, self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.layer)
if self.bias:
self.bias = nn.Parameter(
torch.zeros(
(1, self.num_heads, 1, 1), # self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.bias)
else:
self.bias = 0.0 # hack
def zero_init_with_skip_(self):
"""
Initialize weights to zero matrix if skip connection
"""
with torch.no_grad():
nn.init.zeros_(self.layer)
def zero_init_(self):
"""
Initialize weights to identity matrix if no skip connection
"""
with torch.no_grad():
for i in range(self.layer.shape[0]):
try:
nn.init.eye_(self.layer[i])
except RuntimeError:
with torch.no_grad():
dtype = self.layer[i].dtype
weight = torch.eye(
*self.layer[i].shape,
requires_grad=self.layer[i].requires_grad,
device=self.layer[i].device,
)
self.layer[i] = weight.to(dtype=dtype)
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
"""
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
return x + _x if self.skip_connection else _x
class FeatureMapAdapter(FeatureMapMLP):
"""
Learnable Feature map with bottleneck adapter
as in https://arxiv.org/abs/1902.00751
We don't use but could be fun to try
"""
def __init__(self, hidden_dim: int, *args, **kwargs):
kwargs["skip_connection"] = True
kwargs["bias"] = True
kwargs["zero_init"] = True
self.hidden_dim = hidden_dim
super().__init__(*args, **kwargs)
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
kwargs = {"dtype": self.dtype, "device": self.device}
self.layer0 = nn.Parameter(
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
)
self.layer1 = nn.Parameter(
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.layer0)
nn.init.kaiming_uniform_(self.layer1)
self.bias0 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
)
self.bias1 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.bias0)
nn.init.kaiming_uniform_(self.bias1)
def zero_init_with_skip_(self):
with torch.no_grad():
nn.init.zeros_(self.layer0)
nn.init.zeros_(self.layer1)
nn.init.zeros_(self.bias0)
nn.init.zeros_(self.bias1)
def zero_init_(self):
raise NotImplementedError
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
-> Down-project, apply nonlinearity, up-project; add skip connection
"""
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
_x = F.relu(_x)
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
return x + _x if self.skip_connection else _x

View File

@@ -23,7 +23,7 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36 _flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# Causal linear attention dot product CUDA kernel from fast-transformers # Causal linear attention dot product CUDA kernel from fast-transformers
from .linear_attention import ( from .linear_attention import (
@@ -32,9 +32,7 @@ from .linear_attention import (
causal_dot_product, causal_dot_product,
) )
LOG = logging.getLogger( LOG = logging.getLogger(__name__)
"axolotl.integrations.lolcats.linear_attention.linear_window_attention_sw_long"
)
# ---------------------- # ----------------------

View File

@@ -11,9 +11,7 @@ import torch.nn.functional as F
from .linear_attention import LinearAttentionState from .linear_attention import LinearAttentionState
from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
LOG = logging.getLogger( LOG = logging.getLogger(__name__)
"axolotl.integrations.lolcats.linear_attention.linear_attention_tk_gen"
)
try: try:
from thunderkittens import hedgehog as tk_window_hedgehog_attention from thunderkittens import hedgehog as tk_window_hedgehog_attention

View File

@@ -22,7 +22,8 @@ try:
except ModuleNotFoundError: except ModuleNotFoundError:
_flash_attention_forward = None # Transformers v4.36 _flash_attention_forward = None # Transformers v4.36
from ..model.rotary import apply_rotary_pos_emb from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from .linear_attention import softmax_attention from .linear_attention import softmax_attention
from .linear_window_attention_tk import LolcatsTKWindowAttention from .linear_window_attention_tk import LolcatsTKWindowAttention

View File

@@ -1,336 +0,0 @@
"""
Learnable linear attention feature map classes and functions
"""
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
def init_feature_map(name: str, mlp: nn.Module, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
return FeatureMap(activation_name=name, mlp=mlp, **kwargs)
def init_feature_map_act(name: str, fullspace: bool = True, **kwargs):
"""
Initialize feature map final activation for linear attention
"""
if name == "softmax_dim" and fullspace:
return SoftmaxDim(**kwargs)
elif name == "softmax_dim" and not fullspace:
return SoftmaxDimHalfspace(**kwargs)
elif name == "exp_dim" and fullspace:
return Exp(**kwargs)
elif name == "exp_dim" and not fullspace:
return ExpHalfspace(**kwargs)
elif name == "pos_elu":
return PosELU(**kwargs)
elif name == "relu":
return ReLU(**kwargs)
else:
raise NotImplementedError
def init_learned_kernel(name: str, **kwargs):
"""
Initialize feature map MLP for linear attention
"""
if name == "untied_head_einsum":
return FeatureMapMLP(**kwargs)
elif name == "untied_head_adapter":
return FeatureMapAdapter(**kwargs)
else:
raise NotImplementedError
class FeatureMap(nn.Module):
"""
Final 'activation' of feature map. Can probably be combined with
`FeatureMapMLP` below
Full feature map is like f(xW + b)
-> This is the `f` part
"""
def __init__(
self,
activation_name: str,
head_dim_idx: int = -1,
eps: float = 1e-12,
mlp: Optional[nn.Module] = None,
fullspace: bool = True,
):
super().__init__()
self.head_dim_idx = head_dim_idx
self.eps = eps
self.mlp = mlp if mlp is not None else nn.Identity()
self.activation = init_feature_map_act(activation_name, fullspace, eps=eps)
def forward(self, x: torch.Tensor, *mlp_args, **mlp_kwargs):
"""
Assume x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return self.activation(self.mlp(x, *mlp_args, **mlp_kwargs), x)
def q_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
def k_map(self, *args, **kwargs):
"""
Use for inference in case q and k feature maps differ
"""
return self.forward(*args, **kwargs)
# -----------------------
# Feature map activations
# -----------------------
class FeatureMapAct(nn.Module):
"""
Base class for feature map activations
"""
def __init__(self, eps: float = 1e-12):
super().__init__()
self.eps = eps
def forward(self, x: torch.Tensor, *args, **kwargs):
"""
x.shape is (batch_size, n_heads, seq_len, head_dim)
"""
return x
class PosELU(FeatureMapAct):
"""
1 + ELU activation as in https://arxiv.org/abs/2006.16236
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return (1 + F.elu(x)).clamp(min=self.eps)
class ReLU(FeatureMapAct):
"""
ReLU activation as in https://arxiv.org/abs/2103.13076
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return F.relu(x).clamp(min=self.eps)
class SoftmaxDim(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.cat(
[torch.softmax(x, dim=-1), torch.softmax(-x, dim=-1)], dim=-1
).clamp(min=self.eps)
class SoftmaxDimHalfspace(FeatureMapAct):
"""
Softmax activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
return torch.softmax(x, dim=-1).clamp(min=self.eps)
class Exp(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
x_min = torch.amin(x, dim=-1, keepdim=True)
return torch.cat([torch.exp(x - x_max), torch.exp(-x + x_min)], dim=-1).clamp(
min=self.eps
)
class ExpHalfspace(FeatureMapAct):
"""
Exp activation as in https://arxiv.org/abs/2402.04347
"""
def forward(self, x: torch.Tensor, *args, **kwargs):
x_max = torch.amax(x, dim=-1, keepdim=True)
return torch.exp(x - x_max).clamp(min=self.eps)
# ----------------
# Feature map MLPs
# ----------------
class FeatureMapMLP(nn.Module):
"""
Learnable MLP in feature map.
Full feature map is like f(xW + b)
-> This is the `W` and (optional) `b` part
"""
def __init__(
self,
num_heads: int,
head_dim: int, # input dim
feature_dim: int, # output dim
dtype: torch.dtype,
device: torch.device,
skip_connection: bool = False,
bias: bool = False,
zero_init: bool = False,
normal_init: bool = False,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.feature_dim = feature_dim
self.dtype = dtype
self.device = device
self.skip_connection = skip_connection
self.bias = bias
self.zero_init = zero_init
self.normal_init = normal_init
self.init_weights_()
if self.zero_init: # Zero-out weights or set as identity post-initialization
self.zero_init_with_skip_() if self.skip_connection else self.zero_init_()
if self.normal_init:
with torch.no_grad():
nn.init.normal_(self.layer)
if self.skip_connection:
assertion_fail = f"If self.skip_connection we need self.head_dim == self.feature_dim but self.head_dim is {self.head_dim} != self.feature_dim is {self.feature_dim}"
assert self.head_dim == self.feature_dim, assertion_fail
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
self.layer = nn.Parameter(
torch.zeros(
(self.num_heads, self.head_dim, self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.layer)
if self.bias:
self.bias = nn.Parameter(
torch.zeros(
(1, self.num_heads, 1, 1), # self.feature_dim),
dtype=self.dtype,
device=self.device,
)
)
nn.init.kaiming_uniform_(self.bias)
else:
self.bias = 0.0 # hack
def zero_init_with_skip_(self):
"""
Initialize weights to zero matrix if skip connection
"""
with torch.no_grad():
nn.init.zeros_(self.layer)
def zero_init_(self):
"""
Initialize weights to identity matrix if no skip connection
"""
with torch.no_grad():
for i in range(self.layer.shape[0]):
try:
nn.init.eye_(self.layer[i])
except RuntimeError:
with torch.no_grad():
dtype = self.layer[i].dtype
weight = torch.eye(
*self.layer[i].shape,
requires_grad=self.layer[i].requires_grad,
device=self.layer[i].device,
)
self.layer[i] = weight.to(dtype=dtype)
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
"""
_x = torch.einsum("hdf,bhld->bhlf", self.layer, x) + self.bias
return x + _x if self.skip_connection else _x
class FeatureMapAdapter(FeatureMapMLP):
"""
Learnable Feature map with bottleneck adapter
as in https://arxiv.org/abs/1902.00751
We don't use but could be fun to try
"""
def __init__(self, hidden_dim: int, *args, **kwargs):
kwargs["skip_connection"] = True
kwargs["bias"] = True
kwargs["zero_init"] = True
self.hidden_dim = hidden_dim
super().__init__(*args, **kwargs)
def init_weights_(self):
"""
Initialize (W)eights and (b)iases
"""
kwargs = {"dtype": self.dtype, "device": self.device}
self.layer0 = nn.Parameter(
torch.zeros((self.num_heads, self.head_dim, self.hidden_dim), **kwargs)
)
self.layer1 = nn.Parameter(
torch.zeros((self.num_heads, self.hidden_dim, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.layer0)
nn.init.kaiming_uniform_(self.layer1)
self.bias0 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.hidden_dim), **kwargs)
)
self.bias1 = nn.Parameter(
torch.zeros((1, self.num_heads, 1, self.feature_dim), **kwargs)
)
nn.init.kaiming_uniform_(self.bias0)
nn.init.kaiming_uniform_(self.bias1)
def zero_init_with_skip_(self):
with torch.no_grad():
nn.init.zeros_(self.layer0)
nn.init.zeros_(self.layer1)
nn.init.zeros_(self.bias0)
nn.init.zeros_(self.bias1)
def zero_init_(self):
raise NotImplementedError
def forward(self, x: torch.Tensor):
"""
Assume x.shape is (batch_size, num_heads, seq_len, head_dim)
-> Down-project, apply nonlinearity, up-project; add skip connection
"""
_x = torch.einsum("hde,bhld->bhle", self.layer0, x) + self.bias0
_x = F.relu(_x)
_x = torch.einsum("hef,bhle->bhlf", self.layer1, _x) + self.bias1
return x + _x if self.skip_connection else _x

View File

@@ -1,204 +0,0 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rotary embeddings. Same as usual for Transformer models.
Note these are modified from HF Transformers v4.36, from:
- transformers/models/llama/modeling_llama.py or transformers/models/mistral/modeling_mistral.py
- i.e., https://github.com/huggingface/transformers/blob/a7cab3c283312b8d4de5df3bbe719971e24f4281/src/transformers/models/llama/modeling_llama.py#L123
"""
from typing import Optional
import torch
import torch.nn as nn
def get_rotary_embeddings(
rope_scaling_type: Optional[str] = None,
head_dim: int = 128,
max_position_embeddings: int = 4096,
rope_theta: float = 10000.0,
rope_scaling_factor: float = 1.0,
device: Optional[torch.device] = None,
) -> nn.Module:
"""Return rotary embedding object"""
if rope_scaling_type is None:
return RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
device=device,
)
elif rope_scaling_type == "linear":
return LinearScalingRotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
scaling_factor=rope_scaling_factor,
base=rope_theta,
device=device,
)
elif rope_scaling_type == "dynamic":
return DynamicNTKScalingRotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
scaling_factor=rope_scaling_factor,
base=rope_theta,
device=device,
)
else:
raise NotImplementedError(
f'Sorry rope_scaling_type == "{rope_scaling_type}" not implemented.'
)
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors."""
if position_ids is not None:
cos, sin = cos[position_ids], sin[position_ids]
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Modified from transformers.models.mistral.modeling_mistral (llama.modeling_llama at v4.36)
class RotaryEmbedding(nn.Module):
"""Original Rotary Embeddings from RoFormer https://arxiv.org/abs/2104.09864"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings,
device=self.inv_freq.device,
dtype=torch.get_default_dtype(),
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
"""
Compute rotary embeddings
"""
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
# Copied from transformers/models/llama/modeling_llama.py at v4.36
class LinearScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
# Copied from transformers/models/llama/modeling_llama.py at v4.36
class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
device=None,
scaling_factor=1.0,
):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (
base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
)
freqs = torch.outer(t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

View File

@@ -11,7 +11,7 @@
import logging import logging
from functools import partial from functools import partial
from typing import Any from typing import Any, Optional
from torch import nn from torch import nn
from tqdm import tqdm from tqdm import tqdm
@@ -23,7 +23,6 @@ from transformers.models.llama.modeling_llama import (
LlamaRotaryEmbedding, LlamaRotaryEmbedding,
) )
from .attention import LolcatsLinearAttention
from .configuration_linear_llama import LinearLlamaConfig from .configuration_linear_llama import LinearLlamaConfig
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -36,11 +35,10 @@ class LinearLlamaDecoderLayer(LlamaDecoderLayer):
def __init__(self, config: LinearLlamaConfig, layer_idx: int): def __init__(self, config: LinearLlamaConfig, layer_idx: int):
super().__init__(config, layer_idx) super().__init__(config, layer_idx)
# Replace the attention layer with our custom attention # Replace the attention layer with our custom attention
self.self_attn = LolcatsLinearAttention( self.self_attn = convert_llama_attention(
base_attn=self.self_attn, # type: ignore layer=self, attention_config=config.attention_config
layer_idx=layer_idx,
**config.attention_config,
) )
@@ -110,18 +108,22 @@ class LinearLlamaForCausalLM(LlamaForCausalLM):
if config is None: if config is None:
raise ValueError("Missing config") raise ValueError("Missing config")
# initialize the model with prior weights # initialize a new model with config
new_model = cls(config=config) new_model = cls(config=config)
del new_model.model # remove the default model # remove the default model and lm_head
del new_model.model
del new_model.lm_head
# load converted model, lm_head, and vocab_size from llama model
new_model.model = convert_attention( new_model.model = convert_attention(
model.model, model.model,
attention_config=config.attention_config, attention_config=config.attention_config,
train_attention=train_attention, train_attention=train_attention,
remove_base_attn=remove_base_attn, remove_base_attn=remove_base_attn,
) )
new_model.lm_head = model.lm_head
new_model.lm_head.load_state_dict(model.lm_head.state_dict()) new_model.vocab_size = model.vocab_size
return new_model return new_model
@@ -227,7 +229,7 @@ def traverse_layers(model: nn.Module, verbose: bool = False):
def convert_llama_attention( def convert_llama_attention(
layer: nn.Module, layer: nn.Module,
attention_config: dict, attention_config: dict,
layers: list[nn.Module], # list of layers layers: Optional[list[nn.Module]] = None, # list of layers
train_attention: bool = False, train_attention: bool = False,
remove_base_attn: bool = True, remove_base_attn: bool = True,
): ):
@@ -237,7 +239,7 @@ def convert_llama_attention(
return get_attention(**attention_config)( return get_attention(**attention_config)(
base_attn=layer.self_attn, base_attn=layer.self_attn,
layer_idx=layer.self_attn.layer_idx, # Transformers v4.36 layer_idx=layer.self_attn.layer_idx, # Transformers v4.36
max_layer_idx=len(layers) - 1, max_layer_idx=len(layers) - 1 if layers else None,
train_attention=train_attention, train_attention=train_attention,
remove_base_attn=remove_base_attn, remove_base_attn=remove_base_attn,
) )
@@ -252,39 +254,41 @@ def get_attention(attention_type: str, **kwargs):
kwargs["attention_type"] = attention_type kwargs["attention_type"] = attention_type
if attention_type == "lolcats_llama": if attention_type == "lolcats_llama":
from .attention import LolcatsLinearAttention from .linear_attention import LolcatsLinearAttention
return partial(LolcatsLinearAttention, **kwargs) return partial(LolcatsLinearAttention, **kwargs)
elif attention_type == "lolcats_llama_window_tk": elif attention_type == "lolcats_llama_window_tk":
from .attention import LolcatsTKWindowAttention from .linear_window_attention_tk import LolcatsTKWindowAttention
return partial(LolcatsTKWindowAttention, **kwargs) return partial(LolcatsTKWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw": elif attention_type == "lolcats_llama_window_sw":
from .attention import LolcatsSlidingWindowAttention from .linear_window_attention_sw import LolcatsSlidingWindowAttention
return partial(LolcatsSlidingWindowAttention, **kwargs) return partial(LolcatsSlidingWindowAttention, **kwargs)
elif attention_type == "lolcats_llama_window_sw_linear": elif attention_type == "lolcats_llama_window_sw_linear":
from .attention import LolcatsLinearSlidingWindowAttention from .linear_window_attention_sw_linear import (
LolcatsLinearSlidingWindowAttention,
)
return partial(LolcatsLinearSlidingWindowAttention, **kwargs) return partial(LolcatsLinearSlidingWindowAttention, **kwargs)
# Experimental chunked linear attentions below # Experimental chunked linear attentions below
elif attention_type == "lolcats_long_llama_window_tk": elif attention_type == "lolcats_long_llama_window_tk":
from .attention import LolcatsTKWindowLongAttention from .linear_window_attention_tk_long import LolcatsTKWindowLongAttention
return partial(LolcatsTKWindowLongAttention, **kwargs) return partial(LolcatsTKWindowLongAttention, **kwargs)
elif attention_type == "lolcats_long_llama_window_sw": elif attention_type == "lolcats_long_llama_window_sw":
from .attention import LolcatsSlidingWindowLongAttention from .linear_window_attention_sw_long import LolcatsSlidingWindowLongAttention
return partial(LolcatsSlidingWindowLongAttention, **kwargs) return partial(LolcatsSlidingWindowLongAttention, **kwargs)
# TK generation build (requires Thunderkittens) # TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen": elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LolcatsWindowAttentionTKGen from .linear_window_attention_tk_gen import LolcatsWindowAttentionTKGen
return partial(LolcatsWindowAttentionTKGen, **kwargs) return partial(LolcatsWindowAttentionTKGen, **kwargs)
@@ -302,28 +306,32 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
# LOG.info(f'Returning attention cache based on attention_type == {attention_type}') # LOG.info(f'Returning attention cache based on attention_type == {attention_type}')
elif "lolcats_llama_window_tk_gen" in attention_type: elif "lolcats_llama_window_tk_gen" in attention_type:
from .attention import LinearAttentionTKWindowGenerationCache from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache() return LinearAttentionTKWindowGenerationCache()
elif "llama_window_tk" in attention_type: elif "llama_window_tk" in attention_type:
from .attention import LinearAttentionTKWindowCache from .linear_window_attention_tk import LinearAttentionTKWindowCache
return LinearAttentionTKWindowCache() return LinearAttentionTKWindowCache()
elif "llama_window_sw" in attention_type: elif "llama_window_sw" in attention_type:
from .attention import LinearAttentionSlidingWindowCache from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache() return LinearAttentionSlidingWindowCache()
elif "llama_window_sw_linear" in attention_type: elif "llama_window_sw_linear" in attention_type:
from .attention import LinearAttentionSlidingWindowCache from .linear_window_attention_sw import LinearAttentionSlidingWindowCache
return LinearAttentionSlidingWindowCache() return LinearAttentionSlidingWindowCache()
# TK generation build (requires Thunderkittens) # TK generation build (requires Thunderkittens)
elif attention_type == "lolcats_llama_window_tk_gen": elif attention_type == "lolcats_llama_window_tk_gen":
from .attention import LinearAttentionTKWindowGenerationCache from .linear_window_attention_tk_gen import (
LinearAttentionTKWindowGenerationCache,
)
return LinearAttentionTKWindowGenerationCache() return LinearAttentionTKWindowGenerationCache()
@@ -331,7 +339,7 @@ def get_attention_cache(attention_type: str, past_key_values: Any = None):
return past_key_values return past_key_values
else: else:
from .attention import LinearAttentionState from .linear_attention import LinearAttentionState
return LinearAttentionState() return LinearAttentionState()
@@ -346,3 +354,8 @@ def register_linear_llama():
AutoConfig.register("linear_llama", LinearLlamaConfig) AutoConfig.register("linear_llama", LinearLlamaConfig)
AutoModel.register(LinearLlamaConfig, LinearLlamaModel) AutoModel.register(LinearLlamaConfig, LinearLlamaModel)
AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM) AutoModelForCausalLM.register(LinearLlamaConfig, LinearLlamaForCausalLM)
# registering for auto classes to save files
LinearLlamaConfig.register_for_auto_class("AutoConfig")
LinearLlamaModel.register_for_auto_class("AutoModel")
LinearLlamaForCausalLM.register_for_auto_class("AutoModelForCausalLM")