more fixes to conversion

This commit is contained in:
Wing Lian
2025-01-20 10:51:24 -05:00
parent 38dfd3fadb
commit 623eaca740
5 changed files with 84 additions and 38 deletions

View File

View File

View File

@@ -1,3 +1,5 @@
import json
import math
import os
import re
from pathlib import Path
@@ -9,7 +11,9 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard
from safetensors.torch import save_file
from tqdm import tqdm
from transformers import AutoConfig
from transformers.utils import SAFE_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig
def extract_layer_number(key):
@@ -90,14 +94,20 @@ def low_rank_decomposition(
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
final_rank = min(min(weight.shape), max_rank)
# Distribute S to both to improve numerical precision.
sqrt_S = torch.sqrt(torch.diag(S[:final_rank]))
L = sqrt_S @ Vh[:final_rank, :]
R = U[:, :final_rank] @ sqrt_S
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
return A.to(dtype), B.to(dtype)
def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
    """Return the row-wise (dim=1) L2 norm of the merged weight.

    The merged weight is the base ``weight`` plus the LoRA delta
    ``lora_weight`` scaled by ``scaling`` (rsLoRA scaling factor).
    """
    combined = weight + scaling * lora_weight
    # norm along dim=1, cast back to the merged tensor's dtype
    return torch.linalg.norm(combined, dim=1).to(combined.dtype)
return L.to(dtype), R.to(dtype)
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
"""
@@ -106,25 +116,24 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
"""
device = "cuda" if torch.cuda.is_available() else "mps"
# rslora
scaling = alpha / math.sqrt(rank)
base_weight = avg_weight.to(device)
finetuned_weight = layer_weight.to(device)
final_weight = layer_weight.to(device)
# 1. Compute column norms and directions
# (shape: base_norms, finetuned_norms => (k,))
base_norms = torch.norm(base_weight, dim=0) + 1e-9
finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9
delta_first_pass = final_weight - base_weight
# shape (d, k)
base_dir = base_weight / base_norms
finetuned_dir = finetuned_weight / finetuned_norms
# 2. Delta direction
delta_dir = finetuned_dir - base_dir
delta_for_svd = delta_first_pass / scaling
# 3. Low-rank factorization of the delta direction
A, B = low_rank_decomposition(delta_dir, rank)
# The final magnitudes are just finetuned_norms
return A.cpu(), B.cpu(), finetuned_norms.cpu()
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
lora_weight = lora_B @ lora_A
weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
@@ -142,13 +151,13 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}"
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
yield layernorm_key, weight
elif "post_attention_layernorm" in key:
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
loop_idx = layer_idx // recurse_layers
layer_idx = layer_idx % recurse_layers
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}"
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight"
yield layernorm_key, weight
else:
yield key, weight
@@ -169,6 +178,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re
yield lora_magnitude_key, lora_magnitude
def save_state_dict_to_safetensors(state_dict, save_directory):
os.makedirs(save_directory, exist_ok=True)
weights_name = SAFE_WEIGHTS_NAME
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
@@ -211,7 +221,20 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32):
del state_dict
if index is None:
path_to_weights = os.path.join(save_directory, weights_name)
logger.info(f"Model weights saved in {path_to_weights}")
else:
save_index_file = SAFE_WEIGHTS_INDEX_NAME
save_index_file = os.path.join(save_directory, save_index_file)
# Save the index as well
with open(save_index_file, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
@@ -230,17 +253,19 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
f"divisible by the recurse layers ({recurse_layers})"
)
config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
config.save_pretrained(output_dir)
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
# create a new state_dict to store the RRT model weights
rrt_model_state_dict = {}
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers):
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
rrt_lora_state_dict = {}
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers):
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# combine state dicts into a single state_dict

View File

@@ -1,6 +1,10 @@
import math
import torch
import torch.nn.functional as F
from torch import nn, transpose
from peft.utils import transpose
from torch import nn
class RelaxedRecursiveDoraLinear(nn.Module):
@@ -23,6 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
out_features: int,
B: int,
rank: int,
alpha: int,
fan_in_fan_out: bool = False,
bias: bool = True,
use_dora: bool = True,
@@ -41,9 +46,18 @@ class RelaxedRecursiveDoraLinear(nn.Module):
self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)])
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
# rslora
self.scaling = alpha / math.sqrt(rank)
if use_dora:
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
    """Row-wise (dim=1) L2 norm of the merged weight.

    ``weight`` is first transposed according to ``self.fan_in_fan_out``
    (peft convention), then the LoRA delta ``lora_weight`` scaled by
    ``scaling`` is added before taking the norm.
    """
    base = transpose(weight, self.fan_in_fan_out)
    merged = base + scaling * lora_weight
    return torch.linalg.norm(merged, dim=1).to(merged.dtype)
def forward(self, x, loop_idx: int):
"""
@@ -58,14 +72,16 @@ class RelaxedRecursiveDoraLinear(nn.Module):
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
base_out: torch.Tensor = F.linear(x, transpose(w_base, self.fan_in_fan_out), self.bias)
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype)
w_dora_full: torch.Tensor = lora_B(lora_A(x_eye))
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach())
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
w_dora_norm = w_dora_norm.detach()
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]

View File

@@ -20,8 +20,9 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
Configuration for Relaxed Recursive Llama.
"""
recurse_layers: int
recurse_layers: int = 4
rank: int
alpha: int
class RelaxedRecursiveLlamaMLP(nn.Module):
@@ -31,9 +32,9 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
@@ -58,16 +59,16 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, bias=config.attention_bias
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
)
def forward(
@@ -185,6 +186,8 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
class RelaxedRecursiveLlamaModel(LlamaModel):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaModel, self).__init__(config)
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
@@ -313,6 +316,8 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
config_class = RelaxedRecursiveLlamaConfig
def __init__(self, config):
super(LlamaForCausalLM, self).__init__(config)
self.model = RelaxedRecursiveLlamaModel(config)