From 623eaca7408d3d3ba0a6601b2cdf103bf11cb93c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 20 Jan 2025 10:51:24 -0500 Subject: [PATCH] more fixes to conversion --- src/axolotl/integrations/rrt/README.md | 0 src/axolotl/integrations/rrt/args.py | 0 src/axolotl/integrations/rrt/cli/convert.py | 77 ++++++++++++------- .../integrations/rrt/modeling/linear.py | 24 +++++- .../rrt/modeling/modeling_rrt_llama.py | 21 +++-- 5 files changed, 84 insertions(+), 38 deletions(-) create mode 100644 src/axolotl/integrations/rrt/README.md create mode 100644 src/axolotl/integrations/rrt/args.py diff --git a/src/axolotl/integrations/rrt/README.md b/src/axolotl/integrations/rrt/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rrt/args.py b/src/axolotl/integrations/rrt/args.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index e9fc44562..f755b38e8 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -1,3 +1,5 @@ +import json +import math import os import re from pathlib import Path @@ -9,7 +11,9 @@ from huggingface_hub import snapshot_download, split_torch_state_dict_into_shard from safetensors.torch import save_file from tqdm import tqdm from transformers import AutoConfig -from transformers.utils import SAFE_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_NAME, SAFE_WEIGHTS_INDEX_NAME + +from axolotl.integrations.rrt.modeling.modeling_rrt_llama import RelaxedRecursiveLlamaConfig def extract_layer_number(key): @@ -90,14 +94,20 @@ def low_rank_decomposition( U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False) - final_rank = min(min(weight.shape), max_rank) - # Distribute S to both to improve numerical precision. - sqrt_S = torch.sqrt(torch.diag(S[:final_rank])) - L = sqrt_S @ Vh[:final_rank, :] - R = U[:, :final_rank] @ sqrt_S + sqrt_S = torch.sqrt(torch.diag(S[:max_rank])) + A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols] + B = U[:, :max_rank] @ sqrt_S # shape: [rows, r] + + return A.to(dtype), B.to(dtype) + + +def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = weight + scaling * lora_weight + weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) + return weight_norm - return L.to(dtype), R.to(dtype) def decompose_delta_weight(layer_weight, avg_weight, alpha, rank): """ @@ -106,25 +116,24 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank): """ device = "cuda" if torch.cuda.is_available() else "mps" + # rslora + scaling = alpha / math.sqrt(rank) + base_weight = avg_weight.to(device) - finetuned_weight = layer_weight.to(device) + final_weight = layer_weight.to(device) - # 1. Compute column norms and directions - # (shape: base_norms, finetuned_norms => (k,)) - base_norms = torch.norm(base_weight, dim=0) + 1e-9 - finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9 + delta_first_pass = final_weight - base_weight - # shape (d, k) - base_dir = base_weight / base_norms - finetuned_dir = finetuned_weight / finetuned_norms - - # 2. Delta direction - delta_dir = finetuned_dir - base_dir + delta_for_svd = delta_first_pass / scaling # 3. Low-rank factorization of the delta direction - A, B = low_rank_decomposition(delta_dir, rank) - # The final magnitudes are just finetuned_norms - return A.cpu(), B.cpu(), finetuned_norms.cpu() + lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank) + + lora_weight = lora_B @ lora_A + + weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling) + + return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu() def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12): @@ -142,13 +151,13 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx loop_idx = layer_idx // recurse_layers layer_idx = layer_idx % recurse_layers - layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}" + layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}.weight" yield layernorm_key, weight elif "post_attention_layernorm" in key: # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx loop_idx = layer_idx // recurse_layers layer_idx = layer_idx % recurse_layers - layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}" + layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}.weight" yield layernorm_key, weight else: yield key, weight @@ -169,6 +178,7 @@ def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_re yield lora_magnitude_key, lora_magnitude def save_state_dict_to_safetensors(state_dict, save_directory): + os.makedirs(save_directory, exist_ok=True) weights_name = SAFE_WEIGHTS_NAME filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") @@ -211,7 +221,20 @@ def save_state_dict_to_safetensors(state_dict, save_directory): save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) -def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32): + del state_dict + + if index is None: + path_to_weights = os.path.join(save_directory, weights_name) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + +def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device="mps"): modules_to_recurse = [ "self_attn.q_proj", "self_attn.k_proj", @@ -230,17 +253,19 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank= f"divisible by the recurse layers ({recurse_layers})" ) + config = RelaxedRecursiveLlamaConfig.from_dict({**config.to_dict(), "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha}) + config.save_pretrained(output_dir) model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth")) # create a new state_dict to store the RRT model weights rrt_model_state_dict = {} - for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers): + for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers): rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff rrt_lora_state_dict = {} - for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers): + for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device=device, recurse_layers=recurse_layers): rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() # combine state dicts into a single state_dict diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py index 270e8875a..c5df20e57 100644 --- a/src/axolotl/integrations/rrt/modeling/linear.py +++ b/src/axolotl/integrations/rrt/modeling/linear.py @@ -1,6 +1,10 @@ +import math + import torch import torch.nn.functional as F -from torch import nn, transpose +from peft.utils import transpose +from torch import nn + class RelaxedRecursiveDoraLinear(nn.Module): @@ -23,6 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module): out_features: int, B: int, rank: int, + alpha: int, fan_in_fan_out: bool = False, bias: bool = True, use_dora: bool = True, @@ -41,9 +46,18 @@ class RelaxedRecursiveDoraLinear(nn.Module): self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]) self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]) + # rslora + self.scaling = alpha / math.sqrt(rank) if use_dora: self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)]) + def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: + # calculate L2 norm of weight matrix, column-wise + weight = transpose(weight, self.fan_in_fan_out) + weight = weight + scaling * lora_weight + weight_norm = torch.linalg.norm(weight, dim=1).to(weight.dtype) + return weight_norm + def forward(self, x, loop_idx: int): """ @@ -58,14 +72,16 @@ class RelaxedRecursiveDoraLinear(nn.Module): lora_B: torch.Tensor = self.lora_B_list[loop_idx] magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx] - base_out: torch.Tensor = F.linear(x, transpose(w_base, self.fan_in_fan_out), self.bias) + base_out: torch.Tensor = F.linear(x, w_base, self.bias) x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype) - w_dora_full: torch.Tensor = lora_B(lora_A(x_eye)) + tmp = F.linear(x_eye, lora_A) # [hidden_size, rank] + w_dora_full: torch.Tensor = F.linear(tmp, lora_B) + w_dora_full = w_dora_full.t() lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None) - w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach()) + w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling) w_dora_norm = w_dora_norm.detach() scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features] diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py index 1679a52ed..6b4595435 100644 --- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py +++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py @@ -20,8 +20,9 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig): Configuration for Relaxed Recursive Llama. """ - recurse_layers: int + recurse_layers: int = 4 rank: int + alpha: int class RelaxedRecursiveLlamaMLP(nn.Module): @@ -31,9 +32,9 @@ class RelaxedRecursiveLlamaMLP(nn.Module): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias) - self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias) - self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, bias=config.mlp_bias) + self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias) + self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias) + self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x, loop_idx: int): @@ -58,16 +59,16 @@ class RelaxedRecursiveLlamaAttention(nn.Module): self.is_causal = True self.q_proj = RelaxedRecursiveDoraLinear( - config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias ) self.k_proj = RelaxedRecursiveDoraLinear( - config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias ) self.v_proj = RelaxedRecursiveDoraLinear( - config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias ) self.o_proj = RelaxedRecursiveDoraLinear( - config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, bias=config.attention_bias + config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias ) def forward( @@ -185,6 +186,8 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module): class RelaxedRecursiveLlamaModel(LlamaModel): + config_class = RelaxedRecursiveLlamaConfig + def __init__(self, config): super(LlamaModel, self).__init__(config) self.recurse_loops = config.num_hidden_layers // config.recurse_layers @@ -313,6 +316,8 @@ class RelaxedRecursiveLlamaModel(LlamaModel): class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM): + config_class = RelaxedRecursiveLlamaConfig + def __init__(self, config): super(LlamaForCausalLM, self).__init__(config) self.model = RelaxedRecursiveLlamaModel(config)