Compare commits

..

10 Commits

Author SHA1 Message Date
Wing Lian
791c38dcc3 chore: lint 2025-01-24 13:29:54 -05:00
Wing Lian
0af78a9882 rescale the norm for lora 2025-01-24 13:11:26 -05:00
Wing Lian
fa5efbf235 don't scale delta before decomposing 2025-01-24 13:11:26 -05:00
Wing Lian
59a7ac427d make sure to scale too 2025-01-24 13:11:25 -05:00
Wing Lian
e3393042e5 hopefully fix the lora/dora logic 2025-01-24 13:11:25 -05:00
Wing Lian
08a4e8a7fb refactor a bit 2025-01-24 13:11:25 -05:00
Wing Lian
b582d340b0 save tokenizer too 2025-01-24 13:11:25 -05:00
Wing Lian
474ba1a1b8 chore: lint/formatting 2025-01-24 13:11:25 -05:00
Wing Lian
de771fcb05 fix convert logger and registration 2025-01-24 13:11:25 -05:00
Wing Lian
f32d429db5 fix import path to args 2025-01-24 13:11:25 -05:00
7 changed files with 384 additions and 145 deletions

View File

@@ -4,12 +4,8 @@ Axolotl Plugin for Relaxed Recursive Transformers
import logging
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig, \
RelaxedRecursiveLlamaModel, RelaxedRecursiveLlamaForCausalLM
LOG = logging.getLogger(__name__)
@@ -20,23 +16,10 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
"""
def get_input_args(self):
return "axolotl.integrations.rrt.RelaxedRecursiveTransformerArgs"
return "axolotl.integrations.rrt.args.RelaxedRecursiveTransformerArgs"
def register(self):
LOG.info(
"Registering Relaxed Recursive Transformers modeling with transformers"
)
register_rrt_model()
def register_rrt_model():
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register("llama-rrt", RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM)

View File

@@ -1,3 +1,7 @@
"""
Axolotl config args for Relaxed Recursive Transformers plugin
"""
from pydantic import BaseModel
@@ -5,4 +9,3 @@ class RelaxedRecursiveTransformerArgs(BaseModel):
"""
Arguments pertaining to the Relaxed Recursive Transformer model.
"""
...

View File

@@ -1,4 +1,8 @@
"""
cli script for converting a pretrained model to a relaxed recursive transformer model
"""
import json
import logging
import math
import os
import re
@@ -10,15 +14,19 @@ import torch
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
)
logger = logging.getLogger(__name__)
def extract_layer_number(key):
"""Extract layer number from parameter key."""
match = re.search(r'layers\.(\d+)\.', key)
match = re.search(r"layers\.(\d+)\.", key)
return int(match.group(1)) if match else None
@@ -30,28 +38,30 @@ def iter_parameter_weights(model_path, device="mps"):
:param device: Computing device
:return: generator yielding (parameter key, parameter weight, layer index) tuples
"""
shards = list(model_path.glob('model*.safetensors'))
shards = list(model_path.glob("model*.safetensors"))
if not shards:
raise ValueError(f"No model shards found in {model_path}")
for shard in tqdm(shards, desc="Processing shards"):
with safetensors.safe_open(shard, framework='pt', device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
with safetensors.safe_open(shard, framework="pt", device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
def iter_recursive_parameter_weights(
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
):
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
# to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict = {}
rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key),
None
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
continue
@@ -62,7 +72,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
# setup as storage for suffix with torch.stack
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
else:
rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
rrt_avg_model_state_dict[suffix].append(
weight.to(torch.float32).detach().cpu()
)
for module_name in modules_to_recurse:
for recurse_idx in range(recurse_layers):
@@ -73,8 +85,9 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
# compute the decomposed lora diff from the weight base to the actual weight for each module
def low_rank_decomposition(
weight: torch.Tensor, max_rank: int
weight: torch.Tensor, max_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Decompose a 2D matrix into low-rank matrices L and R using SVD.
@@ -83,18 +96,19 @@ def low_rank_decomposition(
:param max_rank: The maximum rank of the decomposition
:return: A tuple of tensors (L, R)
"""
# pylint: disable=invalid-name
assert (
weight.dim() == 2
weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
assert (
max_rank >= 1
max_rank >= 1
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
dtype = weight.dtype
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
# Distribute S to both to improve numerical precision.
# Distribute S to both to improve numerical precision
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
@@ -109,7 +123,7 @@ def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
return weight_norm
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
"""
Decompose the difference in directions (ΔV) via SVD,
and return (magnitudes, L, R).
@@ -122,36 +136,49 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
base_weight = avg_weight.to(device)
final_weight = layer_weight.to(device)
delta_first_pass = final_weight - base_weight
delta_for_svd = final_weight - base_weight
delta_for_svd = delta_first_pass / scaling
# Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
delta_for_svd, rank
)
# 3. Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
if use_dora:
lora_weight = lora_B @ lora_A
weight_norm = get_weight_norm(
base_weight.to(lora_A.device), lora_weight, scaling
)
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
lora_weight = lora_B @ lora_A
# let's rescale the lora weight to have the same magnitude as the base weight
weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
return lora_A.cpu(), lora_B.cpu(), None
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
rrt_avg_model_state_dict = {}
def iter_dora_parameter_weights(
model_path,
avg_recursive_weights,
modules_to_recurse: list[str],
alpha,
rank,
device="mps",
recurse_layers=12,
use_dora=True,
):
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key),
None
(module for module in modules_to_recurse if module in key), None
)
if matched_module_name is None:
if "input_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
layernorm_key = (
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
)
yield layernorm_key, weight
elif "post_attention_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
@@ -169,22 +196,35 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
suffix = f"{layer_idx}.{matched_module_name}"
prefix = f"model.layers.{suffix}.weight_base"
avg_weight = avg_recursive_weights[prefix]
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
lora_magnitude_key = (
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
)
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
weight,
avg_weight,
alpha,
rank,
use_dora=use_dora,
)
yield lora_a_key, lora_a
yield lora_b_key, lora_b
yield lora_magnitude_key, lora_magnitude
if use_dora:
yield lora_magnitude_key, lora_magnitude
def save_state_dict_to_safetensors(state_dict, save_directory):
os.makedirs(save_directory, exist_ok=True)
weights_name = SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
".safetensors", "{suffix}.safetensors"
)
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
)
# pylint: disable=duplicate-code
# Save index if sharded
index = None
if state_dict_split.is_sharded:
@@ -205,10 +245,10 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
if (
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
filename.startswith(weights_no_suffix)
and os.path.isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)
@@ -219,7 +259,9 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
shard[tensor] = state_dict[tensor].contiguous()
del state_dict[tensor]
save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
save_file(
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
)
del state_dict
@@ -234,7 +276,24 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
def convert_llama_to_rrt(
model_name,
output_dir,
recurse_layers: int = 12,
rank=32,
alpha=32,
device=None,
use_dora=True,
):
if not device:
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
@@ -246,6 +305,7 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
]
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers % recurse_layers != 0:
raise ValueError(
@@ -253,21 +313,41 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
f"divisible by the recurse layers ({recurse_layers})"
)
config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
config = RelaxedRecursiveLlamaConfig.from_dict(
{
**config.to_dict(),
"recurse_layers": recurse_layers,
"rank": rank,
"alpha": alpha,
"use_dora": use_dora,
}
)
config.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
# create a new state_dict to store the RRT model weights
rrt_model_state_dict = {}
logger.info(f"Calculating average recursive weights...")
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
logger.info("Calculating average recursive weights...")
for key, weight in iter_recursive_parameter_weights(
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
logger.info(f"Calculating decomposed lora diff...")
logger.info("Calculating decomposed lora diff...")
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
rrt_lora_state_dict = {}
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
for key, weight in iter_dora_parameter_weights(
model_path,
rrt_model_state_dict,
modules_to_recurse,
alpha=32,
rank=rank,
device=device,
recurse_layers=recurse_layers,
use_dora=use_dora,
):
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# combine state dicts into a single state_dict
@@ -279,10 +359,12 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
if __name__ == "__main__":
# meta-llama/Llama-3.2-1B has 16 hidden layers
if torch.backends.mps.is_available():
device = "mps"
elif torch.cuda.is_available():
device = "cuda"
else:
device = "cpu"
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512, device=device)
# meta-llama/Llama-3.2-3B has 28 hidden layers
convert_llama_to_rrt(
"meta-llama/Llama-3.2-3B",
"/tmp/rrt_model", # nosec
recurse_layers=4,
rank=256,
alpha=512,
use_dora=False,
)

View File

@@ -1,2 +1,25 @@
"""
module for modeling relaxed recursive transformers model
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
from .modeling_rrt_llama import (
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
def register_rrt_model():
pass
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -0,0 +1,16 @@
"""
module for custom configuration for relaxed recursive transformers model
"""
from transformers import LlamaConfig
class RelaxedRecursiveLlamaConfig(LlamaConfig):
"""
Configuration for Relaxed Recursive Llama.
"""
model_type: str = "llama-rrt"
recurse_layers: int = 4
rank: int
alpha: int
use_dora: bool = True

View File

@@ -1,3 +1,6 @@
"""
module for the shared linear layer for the relaxed recursive transformers model
"""
import math
import torch
@@ -6,7 +9,6 @@ from peft.utils import transpose
from torch import nn
class RelaxedRecursiveDoraLinear(nn.Module):
"""
A single linear layer that is "shared" across multiple loop iterations,
@@ -25,7 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
self,
in_features: int,
out_features: int,
B: int,
B: int, # pylint: disable=invalid-name
rank: int,
alpha: int,
fan_in_fan_out: bool = False,
@@ -33,7 +35,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
use_dora: bool = True,
):
super().__init__()
self.B = B
self.B = B # pylint: disable=invalid-name
self.fan_in_fan_out = fan_in_fan_out
self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
@@ -44,13 +46,19 @@ class RelaxedRecursiveDoraLinear(nn.Module):
else:
self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)])
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
)
self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
)
# rslora
self.scaling = alpha / math.sqrt(rank)
self.use_dora = use_dora
if use_dora:
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
self.lora_magnitude_vector_list = nn.ParameterList(
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
)
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
@@ -66,27 +74,43 @@ class RelaxedRecursiveDoraLinear(nn.Module):
:param loop_idx:
:return:
"""
eps = 1e-6
w_base = self.weight_base
w_base = w_base.to(x.dtype)
lora_A: torch.Tensor = self.lora_A_list[loop_idx]
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
loop_idx
]
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
loop_idx
]
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
if self.use_dora:
x_eye: torch.Tensor = torch.eye(
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
w_dora_norm: torch.Tensor = self.get_weight_norm(
w_base, w_dora_full.detach(), self.scaling
)
w_dora_norm = w_dora_norm.detach()
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
0
) # shape [1, out_features]
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
return result_dora
return base_out + lora_out * self.scaling
# scale the lora norm to prevent gradient explosion
orig_norm = torch.linalg.norm(w_base)
update_norm = torch.linalg.norm(lora_out)
scale = orig_norm / (update_norm + eps)
return base_out + lora_out * scale

View File

@@ -1,30 +1,31 @@
import logging
from typing import Tuple, Optional, Unpack, Callable, Union
from typing import Callable, Optional, Tuple, Union, Unpack
import torch
from torch import nn
from transformers import LlamaConfig, Cache, DynamicCache
from transformers import Cache, DynamicCache, LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
class RelaxedRecursiveLlamaConfig(LlamaConfig):
"""
Configuration for Relaxed Recursive Llama.
"""
model_type = "llama-rrt"
recurse_layers: int = 4
rank: int
alpha: int
use_dora: bool = True
# pylint: skip-file
# mypy: ignore-errors
class RelaxedRecursiveLlamaMLP(nn.Module):
@@ -34,13 +35,40 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.gate_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.up_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.down_proj = RelaxedRecursiveDoraLinear(
self.intermediate_size,
self.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx)
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
loop_idx,
)
return down_proj
@@ -54,23 +82,51 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_attention_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.num_attention_heads * self.head_dim,
config.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
def forward(
@@ -81,32 +137,46 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
loop_idx: int,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
query_states = (
self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
if self.config._attn_implementation == "sdpa" and kwargs.get(
"output_attentions", False
):
logger.warning(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
@@ -121,8 +191,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights
return attn_output, attn_weights # pylint: disable=return-value
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
@@ -135,12 +204,24 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
self.self_attn = RelaxedRecursiveLlamaAttention(
config=config, layer_idx=layer_idx
)
self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
self.input_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
self.post_attention_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
def forward(
self,
@@ -152,9 +233,13 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
@@ -196,9 +281,14 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)]
[
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.recurse_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -221,15 +311,25 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
@@ -244,16 +344,24 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds