Compare commits
16 Commits
reentrant-
...
relaxed-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
791c38dcc3 | ||
|
|
0af78a9882 | ||
|
|
fa5efbf235 | ||
|
|
59a7ac427d | ||
|
|
e3393042e5 | ||
|
|
08a4e8a7fb | ||
|
|
b582d340b0 | ||
|
|
474ba1a1b8 | ||
|
|
de771fcb05 | ||
|
|
f32d429db5 | ||
|
|
82005f8eeb | ||
|
|
b439ed3345 | ||
|
|
623eaca740 | ||
|
|
38dfd3fadb | ||
|
|
daa9408233 | ||
|
|
257231ac46 |
@@ -48,9 +48,9 @@ class BasePlugin:
|
||||
Initializes the BasePlugin.
|
||||
"""
|
||||
|
||||
def register(self, cfg): # pylint: disable=unused-argument
|
||||
def register(self): # pylint: disable=unused-argument
|
||||
"""
|
||||
Registers the plugin with the given configuration.
|
||||
Registers the plugin
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
@@ -274,6 +274,7 @@ class PluginManager:
|
||||
try:
|
||||
plugin = load_plugin(plugin_name)
|
||||
self.plugins[plugin_name] = plugin
|
||||
plugin.register()
|
||||
except ImportError:
|
||||
logging.error(f"Failed to load plugin: {plugin_name}")
|
||||
|
||||
|
||||
0
src/axolotl/integrations/rrt/README.md
Normal file
0
src/axolotl/integrations/rrt/README.md
Normal file
25
src/axolotl/integrations/rrt/__init__.py
Normal file
25
src/axolotl/integrations/rrt/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
Axolotl Plugin for Relaxed Recursive Transformers
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.rrt.modeling import register_rrt_model
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RelaxedRecursiveTransformerPlugin(BasePlugin):
|
||||
"""
|
||||
Plugin for Relaxed Recursive Transformers integration with Axolotl
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.rrt.args.RelaxedRecursiveTransformerArgs"
|
||||
|
||||
def register(self):
|
||||
LOG.info(
|
||||
"Registering Relaxed Recursive Transformers modeling with transformers"
|
||||
)
|
||||
register_rrt_model()
|
||||
11
src/axolotl/integrations/rrt/args.py
Normal file
11
src/axolotl/integrations/rrt/args.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Axolotl config args for Relaxed Recursive Transformers plugin
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RelaxedRecursiveTransformerArgs(BaseModel):
|
||||
"""
|
||||
Arguments pertaining to the Relaxed Recursive Transformer model.
|
||||
"""
|
||||
370
src/axolotl/integrations/rrt/cli/convert.py
Normal file
370
src/axolotl/integrations/rrt/cli/convert.py
Normal file
@@ -0,0 +1,370 @@
|
||||
"""
|
||||
cli script for converting a pretrained model to a relaxed recursive transformer model
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
|
||||
|
||||
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
|
||||
RelaxedRecursiveLlamaConfig,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_layer_number(key):
|
||||
"""Extract layer number from parameter key."""
|
||||
match = re.search(r"layers\.(\d+)\.", key)
|
||||
return int(match.group(1)) if match else None
|
||||
|
||||
|
||||
def iter_parameter_weights(model_path, device="mps"):
|
||||
"""
|
||||
iterator over parameter weights in the model shards
|
||||
|
||||
:param model_path: Path to model shards
|
||||
:param device: Computing device
|
||||
:return: generator yielding (parameter key, parameter weight, layer index) tuples
|
||||
"""
|
||||
shards = list(model_path.glob("model*.safetensors"))
|
||||
if not shards:
|
||||
raise ValueError(f"No model shards found in {model_path}")
|
||||
|
||||
for shard in tqdm(shards, desc="Processing shards"):
|
||||
with safetensors.safe_open(shard, framework="pt", device=device) as f:
|
||||
for key in f.keys():
|
||||
layer_idx = extract_layer_number(key)
|
||||
weight = f.get_tensor(key)
|
||||
yield key, weight, layer_idx
|
||||
|
||||
|
||||
def iter_recursive_parameter_weights(
|
||||
model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12
|
||||
):
|
||||
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
||||
# to avoid precision loss when averaging weights across layers
|
||||
rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
|
||||
|
||||
# iterate over all parameter weights in the model shards
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||
# get the matching module name in modules_to_recurse for the current parameter key
|
||||
matched_module_name = next(
|
||||
(module for module in modules_to_recurse if module in key), None
|
||||
)
|
||||
if matched_module_name is None:
|
||||
continue
|
||||
|
||||
recurse_idx = layer_idx % recurse_layers
|
||||
suffix = f"{recurse_idx}.{matched_module_name}"
|
||||
if rrt_avg_model_state_dict.get(suffix) is None:
|
||||
# setup as storage for suffix with torch.stack
|
||||
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
|
||||
else:
|
||||
rrt_avg_model_state_dict[suffix].append(
|
||||
weight.to(torch.float32).detach().cpu()
|
||||
)
|
||||
|
||||
for module_name in modules_to_recurse:
|
||||
for recurse_idx in range(recurse_layers):
|
||||
suffix = f"{recurse_idx}.{module_name}"
|
||||
prefix = f"model.layers.{suffix}"
|
||||
avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0)
|
||||
yield f"{prefix}.weight_base", avg_weight
|
||||
|
||||
# compute the decomposed lora diff from the weight base to the actual weight for each module
|
||||
|
||||
|
||||
def low_rank_decomposition(
|
||||
weight: torch.Tensor, max_rank: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Decompose a 2D matrix into low-rank matrices L and R using SVD.
|
||||
|
||||
:param weight: The matrix to decompose, of shape (H, W)
|
||||
:param max_rank: The maximum rank of the decomposition
|
||||
:return: A tuple of tensors (L, R)
|
||||
"""
|
||||
# pylint: disable=invalid-name
|
||||
assert (
|
||||
weight.dim() == 2
|
||||
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
|
||||
assert (
|
||||
max_rank >= 1
|
||||
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
|
||||
|
||||
dtype = weight.dtype
|
||||
|
||||
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
|
||||
|
||||
# Distribute S to both to improve numerical precision
|
||||
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
|
||||
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
|
||||
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
|
||||
|
||||
return A.to(dtype), B.to(dtype)
|
||||
|
||||
|
||||
def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
|
||||
# calculate L2 norm of weight matrix, column-wise
|
||||
weight = weight + scaling * lora_weight
|
||||
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
|
||||
return weight_norm
|
||||
|
||||
|
||||
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
|
||||
"""
|
||||
Decompose the difference in directions (ΔV) via SVD,
|
||||
and return (magnitudes, L, R).
|
||||
"""
|
||||
device = "cuda" if torch.cuda.is_available() else "mps"
|
||||
|
||||
# rslora
|
||||
scaling = alpha / math.sqrt(rank)
|
||||
|
||||
base_weight = avg_weight.to(device)
|
||||
final_weight = layer_weight.to(device)
|
||||
|
||||
delta_for_svd = final_weight - base_weight
|
||||
|
||||
# Low-rank factorization of the delta direction
|
||||
lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
|
||||
delta_for_svd, rank
|
||||
)
|
||||
|
||||
if use_dora:
|
||||
lora_weight = lora_B @ lora_A
|
||||
weight_norm = get_weight_norm(
|
||||
base_weight.to(lora_A.device), lora_weight, scaling
|
||||
)
|
||||
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
|
||||
|
||||
# let's rescale the lora weight to have the same magnitude as the base weight
|
||||
|
||||
return lora_A.cpu(), lora_B.cpu(), None
|
||||
|
||||
|
||||
def iter_dora_parameter_weights(
|
||||
model_path,
|
||||
avg_recursive_weights,
|
||||
modules_to_recurse: list[str],
|
||||
alpha,
|
||||
rank,
|
||||
device="mps",
|
||||
recurse_layers=12,
|
||||
use_dora=True,
|
||||
):
|
||||
# iterate over all parameter weights in the model shards
|
||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||
# get the matching module name in modules_to_recurse for the current parameter key
|
||||
matched_module_name = next(
|
||||
(module for module in modules_to_recurse if module in key), None
|
||||
)
|
||||
if matched_module_name is None:
|
||||
if "input_layernorm" in key:
|
||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||
loop_idx = layer_idx // recurse_layers
|
||||
layer_idx = layer_idx % recurse_layers
|
||||
layernorm_key = (
|
||||
f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
|
||||
)
|
||||
yield layernorm_key, weight
|
||||
elif "post_attention_layernorm" in key:
|
||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||
loop_idx = layer_idx // recurse_layers
|
||||
layer_idx = layer_idx % recurse_layers
|
||||
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight"
|
||||
yield layernorm_key, weight
|
||||
else:
|
||||
yield key, weight
|
||||
continue
|
||||
|
||||
# figure out the base weight layer for this key
|
||||
loop_idx = layer_idx // recurse_layers
|
||||
layer_idx = layer_idx % recurse_layers
|
||||
suffix = f"{layer_idx}.{matched_module_name}"
|
||||
prefix = f"model.layers.{suffix}.weight_base"
|
||||
avg_weight = avg_recursive_weights[prefix]
|
||||
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
|
||||
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
|
||||
lora_magnitude_key = (
|
||||
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
||||
)
|
||||
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
|
||||
weight,
|
||||
avg_weight,
|
||||
alpha,
|
||||
rank,
|
||||
use_dora=use_dora,
|
||||
)
|
||||
yield lora_a_key, lora_a
|
||||
yield lora_b_key, lora_b
|
||||
if use_dora:
|
||||
yield lora_magnitude_key, lora_magnitude
|
||||
|
||||
|
||||
def save_state_dict_to_safetensors(state_dict, save_directory):
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
weights_name = SAFE_WEIGHTS_NAME
|
||||
|
||||
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
|
||||
".safetensors", "{suffix}.safetensors"
|
||||
)
|
||||
state_dict_split = split_torch_state_dict_into_shards(
|
||||
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
# Save index if sharded
|
||||
index = None
|
||||
if state_dict_split.is_sharded:
|
||||
index = {
|
||||
"metadata": state_dict_split.metadata,
|
||||
"weight_map": state_dict_split.tensor_to_filename,
|
||||
}
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_directory):
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
|
||||
|
||||
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
|
||||
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
||||
|
||||
if (
|
||||
filename.startswith(weights_no_suffix)
|
||||
and os.path.isfile(full_filename)
|
||||
and filename not in state_dict_split.filename_to_tensors.keys()
|
||||
and reg.fullmatch(filename_no_suffix) is not None
|
||||
):
|
||||
os.remove(full_filename)
|
||||
|
||||
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||
for shard_file, tensors in filename_to_tensors:
|
||||
shard = {}
|
||||
for tensor in tensors:
|
||||
shard[tensor] = state_dict[tensor].contiguous()
|
||||
del state_dict[tensor]
|
||||
|
||||
save_file(
|
||||
shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}
|
||||
)
|
||||
|
||||
del state_dict
|
||||
|
||||
if index is None:
|
||||
path_to_weights = os.path.join(save_directory, weights_name)
|
||||
logger.info(f"Model weights saved in {path_to_weights}")
|
||||
else:
|
||||
save_index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
save_index_file = os.path.join(save_directory, save_index_file)
|
||||
# Save the index as well
|
||||
with open(save_index_file, "w", encoding="utf-8") as f:
|
||||
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
|
||||
f.write(content)
|
||||
|
||||
|
||||
def convert_llama_to_rrt(
|
||||
model_name,
|
||||
output_dir,
|
||||
recurse_layers: int = 12,
|
||||
rank=32,
|
||||
alpha=32,
|
||||
device=None,
|
||||
use_dora=True,
|
||||
):
|
||||
if not device:
|
||||
if torch.backends.mps.is_available():
|
||||
device = "mps"
|
||||
elif torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
modules_to_recurse = [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
"self_attn.v_proj",
|
||||
"self_attn.o_proj",
|
||||
"mlp.down_proj",
|
||||
"mlp.gate_proj",
|
||||
"mlp.up_proj",
|
||||
]
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
if num_hidden_layers % recurse_layers != 0:
|
||||
raise ValueError(
|
||||
f"The number of hidden layers ({num_hidden_layers}) in the model must be "
|
||||
f"divisible by the recurse layers ({recurse_layers})"
|
||||
)
|
||||
|
||||
config = RelaxedRecursiveLlamaConfig.from_dict(
|
||||
{
|
||||
**config.to_dict(),
|
||||
"recurse_layers": recurse_layers,
|
||||
"rank": rank,
|
||||
"alpha": alpha,
|
||||
"use_dora": use_dora,
|
||||
}
|
||||
)
|
||||
config.save_pretrained(output_dir)
|
||||
tokenizer.save_pretrained(output_dir)
|
||||
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
||||
|
||||
# create a new state_dict to store the RRT model weights
|
||||
rrt_model_state_dict = {}
|
||||
|
||||
logger.info("Calculating average recursive weights...")
|
||||
for key, weight in iter_recursive_parameter_weights(
|
||||
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
|
||||
):
|
||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||
|
||||
logger.info("Calculating decomposed lora diff...")
|
||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||
rrt_lora_state_dict = {}
|
||||
for key, weight in iter_dora_parameter_weights(
|
||||
model_path,
|
||||
rrt_model_state_dict,
|
||||
modules_to_recurse,
|
||||
alpha=32,
|
||||
rank=rank,
|
||||
device=device,
|
||||
recurse_layers=recurse_layers,
|
||||
use_dora=use_dora,
|
||||
):
|
||||
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||
|
||||
# combine state dicts into a single state_dict
|
||||
rrt_model_state_dict.update(rrt_lora_state_dict)
|
||||
|
||||
# save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards
|
||||
save_state_dict_to_safetensors(rrt_model_state_dict, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||
# meta-llama/Llama-3.2-3B has 28 hidden layers
|
||||
convert_llama_to_rrt(
|
||||
"meta-llama/Llama-3.2-3B",
|
||||
"/tmp/rrt_model", # nosec
|
||||
recurse_layers=4,
|
||||
rank=256,
|
||||
alpha=512,
|
||||
use_dora=False,
|
||||
)
|
||||
25
src/axolotl/integrations/rrt/modeling/__init__.py
Normal file
25
src/axolotl/integrations/rrt/modeling/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""
|
||||
module for modeling relaxed recursive transformers model
|
||||
"""
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
|
||||
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||
from .modeling_rrt_llama import (
|
||||
RelaxedRecursiveLlamaForCausalLM,
|
||||
RelaxedRecursiveLlamaModel,
|
||||
)
|
||||
|
||||
|
||||
def register_rrt_model():
|
||||
"""
|
||||
Register Relaxed Recursive Transformers model with transformers
|
||||
"""
|
||||
|
||||
# Register configs
|
||||
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
|
||||
|
||||
# Register models
|
||||
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
|
||||
AutoModelForCausalLM.register(
|
||||
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
module for custom configuration for relaxed recursive transformers model
|
||||
"""
|
||||
from transformers import LlamaConfig
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
Configuration for Relaxed Recursive Llama.
|
||||
"""
|
||||
|
||||
model_type: str = "llama-rrt"
|
||||
recurse_layers: int = 4
|
||||
rank: int
|
||||
alpha: int
|
||||
use_dora: bool = True
|
||||
116
src/axolotl/integrations/rrt/modeling/linear.py
Normal file
116
src/axolotl/integrations/rrt/modeling/linear.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""
|
||||
module for the shared linear layer for the relaxed recursive transformers model
|
||||
"""
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from peft.utils import transpose
|
||||
from torch import nn
|
||||
|
||||
|
||||
class RelaxedRecursiveDoraLinear(nn.Module):
|
||||
"""
|
||||
A single linear layer that is "shared" across multiple loop iterations,
|
||||
but each iteration has its own DoRA offsets (A_i, B_i, magnitude_i).
|
||||
|
||||
The constructor expects you to specify:
|
||||
- in_features, out_features
|
||||
- B: number of loop iterations (i.e., how many times we "unroll")
|
||||
- fan_in_fan_out: pass True if your underlying base weight is transposed, etc.
|
||||
|
||||
The forward(...) expects an additional argument "loop_idx" in [0..B-1],
|
||||
which picks out the iteration-specific DoRA offsets.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
B: int, # pylint: disable=invalid-name
|
||||
rank: int,
|
||||
alpha: int,
|
||||
fan_in_fan_out: bool = False,
|
||||
bias: bool = True,
|
||||
use_dora: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.B = B # pylint: disable=invalid-name
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
|
||||
self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
|
||||
|
||||
self.use_bias = bias
|
||||
if self.use_bias:
|
||||
self.bias = nn.Parameter(torch.zeros(out_features))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
|
||||
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
|
||||
)
|
||||
self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
|
||||
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
|
||||
)
|
||||
# rslora
|
||||
self.scaling = alpha / math.sqrt(rank)
|
||||
self.use_dora = use_dora
|
||||
if use_dora:
|
||||
self.lora_magnitude_vector_list = nn.ParameterList(
|
||||
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
|
||||
)
|
||||
|
||||
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
|
||||
# calculate L2 norm of weight matrix, column-wise
|
||||
weight = transpose(weight, self.fan_in_fan_out)
|
||||
weight = weight + scaling * lora_weight
|
||||
weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
|
||||
return weight_norm
|
||||
|
||||
def forward(self, x, loop_idx: int):
|
||||
"""
|
||||
|
||||
:param x: hidden state of shape (batch_size, seq_len, in_features)
|
||||
:param loop_idx:
|
||||
:return:
|
||||
"""
|
||||
eps = 1e-6
|
||||
w_base = self.weight_base
|
||||
w_base = w_base.to(x.dtype)
|
||||
|
||||
lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
|
||||
loop_idx
|
||||
]
|
||||
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
|
||||
loop_idx
|
||||
]
|
||||
|
||||
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
||||
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
|
||||
|
||||
if self.use_dora:
|
||||
x_eye: torch.Tensor = torch.eye(
|
||||
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
|
||||
)
|
||||
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
|
||||
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
|
||||
w_dora_full = w_dora_full.t()
|
||||
|
||||
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
|
||||
w_dora_norm: torch.Tensor = self.get_weight_norm(
|
||||
w_base, w_dora_full.detach(), self.scaling
|
||||
)
|
||||
w_dora_norm = w_dora_norm.detach()
|
||||
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
|
||||
0
|
||||
) # shape [1, out_features]
|
||||
|
||||
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
||||
return result_dora
|
||||
|
||||
# scale the lora norm to prevent gradient explosion
|
||||
orig_norm = torch.linalg.norm(w_base)
|
||||
update_norm = torch.linalg.norm(lora_out)
|
||||
scale = orig_norm / (update_norm + eps)
|
||||
|
||||
return base_out + lora_out * scale
|
||||
471
src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
Normal file
471
src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
Normal file
@@ -0,0 +1,471 @@
|
||||
import logging
|
||||
from typing import Callable, Optional, Tuple, Union, Unpack
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Cache, DynamicCache, LlamaConfig
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaRMSNorm,
|
||||
LlamaRotaryEmbedding,
|
||||
apply_rotary_pos_emb,
|
||||
eager_attention_forward,
|
||||
)
|
||||
|
||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||
|
||||
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# pylint: skip-file
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.gate_proj = RelaxedRecursiveDoraLinear(
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.mlp_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.up_proj = RelaxedRecursiveDoraLinear(
|
||||
self.hidden_size,
|
||||
self.intermediate_size,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.mlp_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.down_proj = RelaxedRecursiveDoraLinear(
|
||||
self.intermediate_size,
|
||||
self.hidden_size,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.mlp_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.act_fn = ACT2FN[config.hidden_act]
|
||||
|
||||
def forward(self, x, loop_idx: int):
|
||||
down_proj = self.down_proj(
|
||||
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
|
||||
loop_idx,
|
||||
)
|
||||
return down_proj
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaAttention(nn.Module):
|
||||
"""
|
||||
A single attention layer of the Relaxed Recursive Llama.
|
||||
"""
|
||||
|
||||
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(
|
||||
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
self.num_key_value_groups = (
|
||||
config.num_attention_heads // config.num_key_value_heads
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.is_causal = True
|
||||
|
||||
self.q_proj = RelaxedRecursiveDoraLinear(
|
||||
config.hidden_size,
|
||||
config.num_attention_heads * self.head_dim,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.attention_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.k_proj = RelaxedRecursiveDoraLinear(
|
||||
config.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.attention_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.v_proj = RelaxedRecursiveDoraLinear(
|
||||
config.hidden_size,
|
||||
config.num_key_value_heads * self.head_dim,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.attention_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
self.o_proj = RelaxedRecursiveDoraLinear(
|
||||
config.num_attention_heads * self.head_dim,
|
||||
config.hidden_size,
|
||||
recurse_loops,
|
||||
config.rank,
|
||||
config.alpha,
|
||||
bias=config.attention_bias,
|
||||
use_dora=config.use_dora,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
loop_idx: int,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, self.head_dim)
|
||||
|
||||
query_states = (
|
||||
self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
|
||||
)
|
||||
key_states = (
|
||||
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
|
||||
)
|
||||
value_states = (
|
||||
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
|
||||
)
|
||||
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin
|
||||
)
|
||||
|
||||
if past_key_value is not None:
|
||||
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and kwargs.get(
|
||||
"output_attentions", False
|
||||
):
|
||||
logger.warning(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[
|
||||
self.config._attn_implementation
|
||||
]
|
||||
|
||||
attn_output, attn_weights = attention_interface(
|
||||
self,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
dropout=0.0 if not self.training else self.attention_dropout,
|
||||
scaling=self.scaling,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output, loop_idx)
|
||||
return attn_output, attn_weights # pylint: disable=return-value
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
|
||||
"""
|
||||
A single layer of the Relaxed Recursive Llama decoder.
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = RelaxedRecursiveLlamaAttention(
|
||||
config=config, layer_idx=layer_idx
|
||||
)
|
||||
|
||||
self.mlp = RelaxedRecursiveLlamaMLP(config)
|
||||
|
||||
self.input_layernorm_list = nn.ModuleList(
|
||||
[
|
||||
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for _ in range(recurse_loops)
|
||||
]
|
||||
)
|
||||
self.post_attention_layernorm_list = nn.ModuleList(
|
||||
[
|
||||
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
for _ in range(recurse_loops)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
loop_idx: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Cache] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
position_embeddings: Optional[
|
||||
Tuple[torch.Tensor, torch.Tensor]
|
||||
] = None, # necessary, but kept here for BC
|
||||
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
loop_idx=loop_idx,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm_list[loop_idx](hidden_states)
|
||||
hidden_states = self.mlp(hidden_states, loop_idx)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaModel(LlamaModel):
|
||||
config_class = RelaxedRecursiveLlamaConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super(LlamaModel, self).__init__(config)
|
||||
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(
|
||||
config.vocab_size, config.hidden_size, self.padding_idx
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
|
||||
for layer_idx in range(config.recurse_layers)
|
||||
]
|
||||
)
|
||||
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=config)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Cache] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
"You must specify exactly one of input_ids or inputs_embeds"
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training and use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
if use_cache and past_key_values is None:
|
||||
past_key_values = DynamicCache()
|
||||
|
||||
if cache_position is None:
|
||||
past_seen_tokens = (
|
||||
past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
)
|
||||
cache_position = torch.arange(
|
||||
past_seen_tokens,
|
||||
past_seen_tokens + inputs_embeds.shape[1],
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = cache_position.unsqueeze(0)
|
||||
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask,
|
||||
inputs_embeds,
|
||||
cache_position,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
# create position embeddings to be shared across the decoder layers
|
||||
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
|
||||
for loop_idx in range(self.recurse_loops):
|
||||
for decoder_layer in self.layers[: self.config.recurse_layers]:
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
layer_outputs = self._gradient_checkpointing_func(
|
||||
decoder_layer.__call__,
|
||||
hidden_states,
|
||||
loop_idx,
|
||||
causal_mask,
|
||||
position_ids,
|
||||
past_key_values,
|
||||
output_attentions,
|
||||
use_cache,
|
||||
cache_position,
|
||||
position_embeddings,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
loop_idx,
|
||||
attention_mask=causal_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_values,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cache_position=cache_position,
|
||||
position_embeddings=position_embeddings,
|
||||
**flash_attn_kwargs,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
output = BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=past_key_values if use_cache else None,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
return output if return_dict else output.to_tuple()
|
||||
|
||||
|
||||
class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
|
||||
config_class = RelaxedRecursiveLlamaConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super(LlamaForCausalLM, self).__init__(config)
|
||||
self.model = RelaxedRecursiveLlamaModel(config)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and the number of all parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
lora_params = 0
|
||||
for name, param in self.named_parameters():
|
||||
num_params = param.numel()
|
||||
# if using DS Zero 3 and the weights are initialized empty
|
||||
if num_params == 0 and hasattr(param, "ds_numel"):
|
||||
num_params = param.ds_numel
|
||||
|
||||
# Due to the design of 4bit linear layers from bitsandbytes
|
||||
# one needs to multiply the number of parameters by 2 to get
|
||||
# the correct number of parameters
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
if hasattr(param, "element_size"):
|
||||
num_bytes = param.element_size()
|
||||
elif not hasattr(param, "quant_storage"):
|
||||
num_bytes = 1
|
||||
else:
|
||||
num_bytes = param.quant_storage.itemsize
|
||||
num_params = num_params * 2 * num_bytes
|
||||
|
||||
all_param += num_params
|
||||
if param.requires_grad:
|
||||
trainable_params += num_params
|
||||
if "lora_" in name:
|
||||
lora_params += num_params
|
||||
|
||||
return trainable_params, all_param, lora_params
|
||||
Reference in New Issue
Block a user