From 38dfd3fadba227c49335fa2a683dd2d9632659e4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 19 Jan 2025 22:19:33 -0500 Subject: [PATCH] wip conversion cli --- src/axolotl/integrations/rrt/cli/convert.py | 193 ++++++++++++++++-- .../rrt/modeling/modeling_rrt_llama.py | 12 +- 2 files changed, 181 insertions(+), 24 deletions(-) diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index e3674b4be..e9fc44562 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -1,11 +1,15 @@ +import os +import re from pathlib import Path -from typing import re +from typing import Tuple import safetensors import torch -from huggingface_hub import snapshot_download +from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards +from safetensors.torch import save_file from tqdm import tqdm from transformers import AutoConfig +from transformers.utils import SAFE_WEIGHTS_NAME def extract_layer_number(key): @@ -14,7 +18,7 @@ def extract_layer_number(key): return int(match.group(1)) if match else None -def iter_parameter_weights(model_path, device="cpu"): +def iter_parameter_weights(model_path, device="mps"): """ iterator over parameter weights in the model shards @@ -33,7 +37,7 @@ def iter_parameter_weights(model_path, device="cpu"): weight = f.get_tensor(key) yield key, weight, layer_idx -def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="cpu", recurse_layers=12): +def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12): # setup placeholder state_dict for recursive weights, need to keep in float32 precision # to avoid precision loss when averaging weights across layers rrt_avg_model_state_dict = {} @@ -46,30 +50,168 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], None ) if matched_module_name is None: - if "input_layernorm" in key: - # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx - yield - else: - yield key, weight + continue recurse_idx = layer_idx % recurse_layers suffix = f"{recurse_idx}.{matched_module_name}" - prefix = f"model.layers.{suffix}." if rrt_avg_model_state_dict.get(suffix) is None: # setup as storage for suffix with torch.stack - rrt_avg_model_state_dict[suffix] = torch.stack([weight.to(torch.float32).detach().cpu()]) + rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()] else: - rrt_avg_model_state_dict[suffix] = torch.cat([rrt_avg_model_state_dict[suffix], weight.to(torch.float32).detach().cpu()]) + rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu()) for module_name in modules_to_recurse: for recurse_idx in range(recurse_layers): suffix = f"{recurse_idx}.{module_name}" - prefix = f"model.layers.{suffix}." - avg_weight = rrt_avg_model_state_dict[suffix].mean(dim=0) - yield f"{prefix}.weight", avg_weight + prefix = f"model.layers.{suffix}" + avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0) + yield f"{prefix}.weight_base", avg_weight + + # compute the decomposed lora diff from the weight base to the actual weight for each module + +def low_rank_decomposition( + weight: torch.Tensor, max_rank: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Decompose a 2D matrix into low-rank matrices L and R using SVD. + + :param weight: The matrix to decompose, of shape (H, W) + :param max_rank: The maximum rank of the decomposition + :return: A tuple of tensors (L, R) + """ + assert ( + weight.dim() == 2 + ), f"Only support 2D matrix, but input has {weight.dim()} dimensions." + assert ( + max_rank >= 1 + ), f"Maximum rank must be a positive integer, but input max_rank={max_rank}." + + dtype = weight.dtype + + U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False) + + final_rank = min(min(weight.shape), max_rank) + + # Distribute S to both to improve numerical precision. + sqrt_S = torch.sqrt(torch.diag(S[:final_rank])) + L = sqrt_S @ Vh[:final_rank, :] + R = U[:, :final_rank] @ sqrt_S + + return L.to(dtype), R.to(dtype) + +def decompose_delta_weight(layer_weight, avg_weight, alpha, rank): + """ + Decompose the difference in directions (ΔV) via SVD, + and return (magnitudes, L, R). + """ + device = "cuda" if torch.cuda.is_available() else "mps" + + base_weight = avg_weight.to(device) + finetuned_weight = layer_weight.to(device) + + # 1. Compute column norms and directions + # (shape: base_norms, finetuned_norms => (k,)) + base_norms = torch.norm(base_weight, dim=0) + 1e-9 + finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9 + + # shape (d, k) + base_dir = base_weight / base_norms + finetuned_dir = finetuned_weight / finetuned_norms + + # 2. Delta direction + delta_dir = finetuned_dir - base_dir + + # 3. Low-rank factorization of the delta direction + A, B = low_rank_decomposition(delta_dir, rank) + # The final magnitudes are just finetuned_norms + return A.cpu(), B.cpu(), finetuned_norms.cpu() -def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12): +def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12): + rrt_avg_model_state_dict = {} + + # iterate over all parameter weights in the model shards + for key, weight, layer_idx in iter_parameter_weights(model_path): + # get the matching module name in modules_to_recurse for the current parameter key + matched_module_name = next( + (module for module in modules_to_recurse if module in key), + None + ) + if matched_module_name is None: + if "input_layernorm" in key: + # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx + loop_idx = layer_idx // recurse_layers + layer_idx = layer_idx % recurse_layers + layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}" + yield layernorm_key, weight + elif "post_attention_layernorm" in key: + # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx + loop_idx = layer_idx // recurse_layers + layer_idx = layer_idx % recurse_layers + layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}" + yield layernorm_key, weight + else: + yield key, weight + continue + + # figure out the base weight layer for this key + loop_idx = layer_idx // recurse_layers + layer_idx = layer_idx % recurse_layers + suffix = f"{layer_idx}.{matched_module_name}" + prefix = f"model.layers.{suffix}.weight_base" + avg_weight = avg_recursive_weights[prefix] + lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}" + lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}" + lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}" + lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank) + yield lora_a_key, lora_a + yield lora_b_key, lora_b + yield lora_magnitude_key, lora_magnitude + +def save_state_dict_to_safetensors(state_dict, save_directory): + weights_name = SAFE_WEIGHTS_NAME + + filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors") + state_dict_split = split_torch_state_dict_into_shards( + state_dict, filename_pattern=filename_pattern, max_shard_size="1GB" + ) + # Save index if sharded + index = None + if state_dict_split.is_sharded: + index = { + "metadata": state_dict_split.metadata, + "weight_map": state_dict_split.tensor_to_filename, + } + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in state_dict_split.filename_to_tensors.keys() + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + filename_to_tensors = state_dict_split.filename_to_tensors.items() + for shard_file, tensors in filename_to_tensors: + shard = {} + for tensor in tensors: + shard[tensor] = state_dict[tensor].contiguous() + del state_dict[tensor] + + save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"}) + +def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32): modules_to_recurse = [ "self_attn.q_proj", "self_attn.k_proj", @@ -88,12 +230,25 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12): f"divisible by the recurse layers ({recurse_layers})" ) - model_path = Path(snapshot_download(model_name)) + model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth")) # create a new state_dict to store the RRT model weights rrt_model_state_dict = {} - for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="cpu", recurse_layers=recurse_layers): + for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers): rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() - # split_torch_state_dict_into_shards(...) + # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff + rrt_lora_state_dict = {} + for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers): + rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() + + # combine state dicts into a single state_dict + rrt_model_state_dict.update(rrt_lora_state_dict) + + # save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards + save_state_dict_to_safetensors(rrt_model_state_dict, output_dir) + + +if __name__ == "__main__": + convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32) diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py index 02497ef86..1679a52ed 100644 --- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py +++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py @@ -1,8 +1,9 @@ +import logging from typing import Tuple, Optional, Unpack, Callable, Union import torch from torch import nn -from transformers import LlamaConfig, Cache, logger, DynamicCache +from transformers import LlamaConfig, Cache, DynamicCache from transformers.activations import ACT2FN from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast @@ -12,6 +13,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear +logger = logging.getLogger(__name__) class RelaxedRecursiveLlamaConfig(LlamaConfig): """ @@ -25,7 +27,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig): class RelaxedRecursiveLlamaMLP(nn.Module): def __init__(self, config: RelaxedRecursiveLlamaConfig): super().__init__() - recurse_loops = config.num_layers // config.recurse_layers + recurse_loops = config.num_hidden_layers // config.recurse_layers self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size @@ -46,7 +48,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module): def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int): super().__init__() - recurse_loops = config.num_layers // config.recurse_layers + recurse_loops = config.num_hidden_layers // config.recurse_layers self.config = config self.layer_idx = layer_idx self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) @@ -127,7 +129,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module): def __init__(self, config: LlamaConfig, layer_idx: int): super().__init__() - recurse_loops = config.num_layers // config.recurse_layers + recurse_loops = config.num_hidden_layers // config.recurse_layers self.hidden_size = config.hidden_size self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx) @@ -185,7 +187,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module): class RelaxedRecursiveLlamaModel(LlamaModel): def __init__(self, config): super(LlamaModel, self).__init__(config) - self.recurse_loops = config.num_layers // config.recurse_layers + self.recurse_loops = config.num_hidden_layers // config.recurse_layers self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size