more fixes to conversion
src/axolotl/integrations/rrt/README.md
src/axolotl/integrations/rrt/args.py
@@ -1,3 +1,5 @@
+import json
 import math
 import os
+import re
 from pathlib import Path
@@ -9,7 +11,9 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
 from safetensors.torch import save_file
 from tqdm import tqdm
 from transformers import AutoConfig
-from transformers.utils import SAFE_WEIGHTS_NAME
+from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME
+
+from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig


 def extract_layer_number(key):
@@ -90,14 +94,20 @@ def low_rank_decomposition(

     U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)

-    final_rank = min(min(weight.shape), max_rank)
-
     # Distribute S to both to improve numerical precision.
-    sqrt_S = torch.sqrt(torch.diag(S[:final_rank]))
-    L = sqrt_S @ Vh[:final_rank, :]
-    R = U[:, :final_rank] @ sqrt_S
+    sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
+    A = sqrt_S @ Vh[:max_rank, :]  # shape: [r, cols]
+    B = U[:, :max_rank] @ sqrt_S  # shape: [rows, r]

-    return L.to(dtype), R.to(dtype)
+    return A.to(dtype), B.to(dtype)
+
+
+def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
+    # calculate L2 norm of weight matrix, column-wise
+    weight = weight + scaling * lora_weight
+    weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
+    return weight_norm


 def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     """
@@ -106,25 +116,24 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     """
     device = "cuda" if torch.cuda.is_available() else "mps"

     # rslora
     scaling = alpha / math.sqrt(rank)

     base_weight = avg_weight.to(device)
-    finetuned_weight = layer_weight.to(device)
+    final_weight = layer_weight.to(device)

-    # 1. Compute column norms and directions
-    # (shape: base_norms, finetuned_norms => (k,))
-    base_norms = torch.norm(base_weight, dim=0) + 1e-9
-    finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9
+    delta_first_pass = final_weight - base_weight

-    # shape (d, k)
-    base_dir = base_weight / base_norms
-    finetuned_dir = finetuned_weight / finetuned_norms
-
-    # 2. Delta direction
-    delta_dir = finetuned_dir - base_dir
+    delta_for_svd = delta_first_pass / scaling

-    # 3. Low-rank factorization of the delta direction
-    A, B = low_rank_decomposition(delta_dir, rank)
-    # The final magnitudes are just finetuned_norms
-    return A.cpu(), B.cpu(), finetuned_norms.cpu()
+    lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
+
+    lora_weight = lora_B @ lora_A
+
+    weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
+
+    return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()


 def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
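For context, the reworked decomposition above boils down to the following standalone sketch (random tensors and illustrative shapes, not the integration's API): factor the unscaled delta with an SVD, split sqrt(S) between both factors, and keep a per-output-feature norm of the merged weight.

import math

import torch

rows, cols, rank, alpha = 64, 48, 8, 32
scaling = alpha / math.sqrt(rank)  # rsLoRA-style scaling, as in the diff

base = torch.randn(rows, cols)
finetuned = base + 0.01 * torch.randn(rows, cols)

# factor the unscaled delta so that scaling * (B @ A) approximates finetuned - base
delta = (finetuned - base) / scaling
U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
sqrt_S = torch.sqrt(torch.diag(S[:rank]))
A = sqrt_S @ Vh[:rank, :]   # [rank, cols]
B = U[:, :rank] @ sqrt_S    # [rows, rank]

# DoRA-style magnitude: L2 norm of the merged weight along dim=1
merged = base + scaling * (B @ A)
weight_norm = torch.linalg.norm(merged, dim=1)  # [rows]
print(A.shape, B.shape, weight_norm.shape)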
@@ -142,13 +151,13 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
             # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
             loop_idx = layer_idx // recurse_layers
             layer_idx = layer_idx % recurse_layers
-            layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}"
+            layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight"
             yield layernorm_key, weight
         elif "post_attention_layernorm" in key:
             # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
             loop_idx = layer_idx // recurse_layers
             layer_idx = layer_idx % recurse_layers
-            layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}"
+            layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight"
             yield layernorm_key, weight
         else:
             yield key, weight
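The remapped keys above fold an absolute layer index into a position within the shared block plus a loop index; a quick illustration of the arithmetic (the 24-layer setup is hypothetical):

recurse_layers = 12  # layers per recursive block, matching the default above

for layer_idx in (0, 5, 11, 12, 17, 23):
    loop_idx = layer_idx // recurse_layers   # which pass through the shared block
    local_idx = layer_idx % recurse_layers   # position inside the shared block
    print(layer_idx, "->", f"model.layers.{local_idx}.input_layernorm_list.{loop_idx}.weight")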
@@ -169,6 +178,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
             yield lora_magnitude_key, lora_magnitude


 def save_state_dict_to_safetensors(state_dict, save_directory):
     os.makedirs(save_directory, exist_ok=True)
     weights_name = SAFE_WEIGHTS_NAME

     filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
@@ -211,7 +221,20 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
         save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})

+    del state_dict
+
+    if index is None:
+        path_to_weights = os.path.join(save_directory, weights_name)
+        logger.info(f"Model weights saved in {path_to_weights}")
+    else:
+        save_index_file = SAFE_WEIGHTS_INDEX_NAME
+        save_index_file = os.path.join(save_directory, save_index_file)
+        # Save the index as well
+        with open(save_index_file, "w", encoding="utf-8") as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+

-def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32):
+def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"):
     modules_to_recurse = [
         "self_attn.q_proj",
         "self_attn.k_proj",
@@ -230,17 +253,19 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32):
             f"divisible by the recurse layers ({recurse_layers})"
         )

+    config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha})
+    config.save_pretrained(output_dir)
     model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))

     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}

-    for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers):
+    for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()

     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
-    for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers):
+    for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers):
         rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()

     # combine state dicts into a single state_dict
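With the updated signature, a call site would look roughly like this; the model id and output directory are placeholders, and the import is omitted because the converter's module path is not shown in this diff:

# hypothetical invocation of the updated converter
convert_llama_to_rrt(
    "meta-llama/Llama-3.2-1B",   # placeholder model id
    "./llama-rrt-out",           # placeholder output directory
    recurse_layers=4,            # must evenly divide num_hidden_layers
    rank=32,
    alpha=32,
    device="cuda",
)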
@@ -1,6 +1,10 @@
+import math
+
 import torch
 import torch.nn.functional as F
-from torch import nn, transpose
+from peft.utils import transpose
+from torch import nn


 class RelaxedRecursiveDoraLinear(nn.Module):
@@ -23,6 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         out_features: int,
         B: int,
         rank: int,
+        alpha: int,
         fan_in_fan_out: bool = False,
         bias: bool = True,
         use_dora: bool = True,
@@ -41,9 +46,18 @@ class RelaxedRecursiveDoraLinear(nn.Module):

         self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)])
         self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
+        # rslora
+        self.scaling = alpha / math.sqrt(rank)
         if use_dora:
             self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])

+    def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
+        # calculate L2 norm of weight matrix, column-wise
+        weight = transpose(weight, self.fan_in_fan_out)
+        weight = weight + scaling * lora_weight
+        weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype)
+        return weight_norm
+
     def forward(self, x, loop_idx: int):
         """
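The "# rslora" comment refers to rank-stabilized LoRA scaling, which uses alpha / sqrt(rank) rather than the classic alpha / rank; the difference is easy to see numerically:

import math

alpha, rank = 32, 32
print(alpha / rank)             # classic LoRA scaling -> 1.0
print(alpha / math.sqrt(rank))  # rsLoRA scaling used here -> about 5.66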
@@ -58,14 +72,16 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
         magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]

-        base_out: torch.Tensor = F.linear(x, transpose(w_base, self.fan_in_fan_out), self.bias)
+        base_out: torch.Tensor = F.linear(x, w_base, self.bias)

         x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype)
-        w_dora_full: torch.Tensor = lora_B(lora_A(x_eye))
+        tmp = F.linear(x_eye, lora_A)  # [hidden_size, rank]
+        w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
+        w_dora_full = w_dora_full.t()

         lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)

-        w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach())
+        w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
         w_dora_norm = w_dora_norm.detach()
         scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0)  # shape [1, out_features]
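The hunk above stops before the outputs are combined, but the pieces follow the usual DoRA recipe: materialize the low-rank weight, take a detached per-output-feature norm of the merged weight, and rescale by the learned magnitude. A plain-tensor sketch of that recipe (shapes and the final combination are assumptions based on standard DoRA, not a copy of this module):

import torch
import torch.nn.functional as F

in_f, out_f, rank, scaling = 16, 24, 4, 8.0
x = torch.randn(2, in_f)
w_base = torch.randn(out_f, in_f)
lora_A = torch.randn(rank, in_f)
lora_B = torch.randn(out_f, rank)
magnitude = torch.ones(out_f)

base_out = F.linear(x, w_base)            # frozen base projection
w_lora = lora_B @ lora_A                  # [out_f, in_f] low-rank update
lora_out = F.linear(x, w_lora)

# detached per-output-feature norm of the merged weight, as in the diff
w_norm = torch.linalg.norm(w_base + scaling * w_lora, dim=1).detach()
scale_factor = (magnitude / w_norm).unsqueeze(0)   # [1, out_f]

out = scale_factor * (base_out + scaling * lora_out)
print(out.shape)  # torch.Size([2, 24])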
@@ -20,8 +20,9 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     Configuration for Relaxed Recursive Llama.
     """

-    recurse_layers: int
+    recurse_layers: int = 4
     rank: int
+    alpha: int


 class RelaxedRecursiveLlamaMLP(nn.Module):
@@ -31,9 +32,9 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
-        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
-        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, bias=config.mlp_bias)
+        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
+        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
+        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x, loop_idx: int):
@@ -58,16 +59,16 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
         self.is_causal = True

         self.q_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
+            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
         )
         self.k_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
         )
         self.v_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
         )
         self.o_proj = RelaxedRecursiveDoraLinear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, bias=config.attention_bias
+            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
         )

     def forward(
@@ -185,6 +186,8 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):


 class RelaxedRecursiveLlamaModel(LlamaModel):
+    config_class = RelaxedRecursiveLlamaConfig
+
     def __init__(self, config):
         super(LlamaModel, self).__init__(config)
         self.recurse_loops = config.num_hidden_layers // config.recurse_layers
@@ -313,6 +316,8 @@ class RelaxedRecursiveLlamaModel(LlamaModel):


 class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
+    config_class = RelaxedRecursiveLlamaConfig
+
     def __init__(self, config):
         super(LlamaForCausalLM, self).__init__(config)
         self.model = RelaxedRecursiveLlamaModel(config)