wip conversion cli
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import re
|
||||
from typing import Tuple
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
|
||||
from safetensors.torch import save_file
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoConfig
|
||||
from transformers.utils import SAFE_WEIGHTS_NAME
|
||||
|
||||
|
||||
def extract_layer_number(key):
|
||||
@@ -14,7 +18,7 @@ def extract_layer_number(key):
|
||||
return int(match.group(1)) if match else None
|
||||
|
||||
|
||||
def iter_parameter_weights(model_path, device="cpu"):
|
||||
def iter_parameter_weights(model_path, device="mps"):
|
||||
"""
|
||||
iterator over parameter weights in the model shards
|
||||
|
||||
@@ -33,7 +37,7 @@ def iter_parameter_weights(model_path, device="cpu"):
|
||||
weight = f.get_tensor(key)
|
||||
yield key, weight, layer_idx
|
||||
|
||||
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="cpu", recurse_layers=12):
|
||||
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
|
||||
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
||||
# to avoid precision loss when averaging weights across layers
|
||||
rrt_avg_model_state_dict = {}
|
||||
@@ -46,30 +50,168 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
||||
None
|
||||
)
|
||||
if matched_module_name is None:
|
||||
if "input_layernorm" in key:
|
||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||
yield
|
||||
else:
|
||||
yield key, weight
|
||||
continue
|
||||
|
||||
recurse_idx = layer_idx % recurse_layers
|
||||
suffix = f"{recurse_idx}.{matched_module_name}"
|
||||
prefix = f"model.layers.{suffix}."
|
||||
if rrt_avg_model_state_dict.get(suffix) is None:
|
||||
# setup as storage for suffix with torch.stack
|
||||
rrt_avg_model_state_dict[suffix] = torch.stack([weight.to(torch.float32).detach().cpu()])
|
||||
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
|
||||
else:
|
||||
rrt_avg_model_state_dict[suffix] = torch.cat([rrt_avg_model_state_dict[suffix], weight.to(torch.float32).detach().cpu()])
|
||||
rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
|
||||
|
||||
for module_name in modules_to_recurse:
|
||||
for recurse_idx in range(recurse_layers):
|
||||
suffix = f"{recurse_idx}.{module_name}"
|
||||
prefix = f"model.layers.{suffix}."
|
||||
avg_weight = rrt_avg_model_state_dict[suffix].mean(dim=0)
|
||||
yield f"{prefix}.weight", avg_weight
|
||||
prefix = f"model.layers.{suffix}"
|
||||
avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0)
|
||||
yield f"{prefix}.weight_base", avg_weight
|
||||
|
||||
# compute the decomposed lora diff from the weight base to the actual weight for each module
|
||||
|
||||
def low_rank_decomposition(
    weight: torch.Tensor, max_rank: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Factor a 2D matrix into two low-rank matrices using a truncated SVD.

    The square root of each kept singular value is folded into both factors,
    so that ``R @ L`` approximates ``weight``.

    :param weight: The matrix to decompose, of shape (H, W)
    :param max_rank: Upper bound on the rank that is kept; the effective rank
        is additionally capped by the smaller dimension of ``weight``
    :return: A tuple of tensors (L, R) with shapes (rank, W) and (H, rank),
        both cast back to the dtype of ``weight``
    """
    assert (
        2 == weight.dim()
    ), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
    assert (
        1 <= max_rank
    ), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."

    # Never keep more components than the matrix can actually have.
    kept_rank = min(*weight.shape, max_rank)

    # The SVD is always computed in float32, whatever the input dtype is.
    left_vecs, singular_vals, right_vecs_h = torch.linalg.svd(
        weight.float(), full_matrices=False
    )

    # Split the singular values evenly between both factors for better
    # numerical precision.
    root_scale = torch.sqrt(torch.diag(singular_vals[:kept_rank]))
    row_factor = root_scale @ right_vecs_h[:kept_rank, :]
    col_factor = left_vecs[:, :kept_rank] @ root_scale

    out_dtype = weight.dtype
    return row_factor.to(out_dtype), col_factor.to(out_dtype)
|
||||
|
||||
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
    """
    Decompose the difference in column directions (ΔV) between a layer weight
    and its shared base weight via SVD.

    Both matrices are normalised column-wise, the base directions are
    subtracted from the layer directions, and the resulting delta is factored
    with ``low_rank_decomposition``.

    :param layer_weight: 2D weight of the original (per-layer) module
    :param avg_weight: 2D shared base weight, same shape as ``layer_weight``
    :param alpha: Currently unused; kept so the call signature stays stable
    :param rank: Maximum rank passed to ``low_rank_decomposition``
    :return: A tuple (A, B, magnitudes), all moved to CPU. ``A`` and ``B`` are
        the two factors in the order ``low_rank_decomposition`` returns them,
        and ``magnitudes`` holds the column norms of ``layer_weight``
        (plus a 1e-9 epsilon).
    """
    # Pick the best available accelerator, but fall back to CPU so the
    # conversion also runs on machines that have neither CUDA nor MPS.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    base_weight = avg_weight.to(device)
    finetuned_weight = layer_weight.to(device)

    # 1. Compute column norms and directions
    # (shape: base_norms, finetuned_norms => (k,)); the epsilon guards
    # against division by zero for all-zero columns.
    base_norms = torch.norm(base_weight, dim=0) + 1e-9
    finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9

    # shape (d, k)
    base_dir = base_weight / base_norms
    finetuned_dir = finetuned_weight / finetuned_norms

    # 2. Delta direction
    delta_dir = finetuned_dir - base_dir

    # 3. Low-rank factorization of the delta direction
    A, B = low_rank_decomposition(delta_dir, rank)

    # The final magnitudes are just finetuned_norms
    return A.cpu(), B.cpu(), finetuned_norms.cpu()
|
||||
|
||||
|
||||
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12):
|
||||
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
    """
    Iterate over the model shards and yield the per-loop parameters of the
    recursive model as ``(key, tensor)`` pairs.

    For every parameter whose key contains one of ``modules_to_recurse``, the
    weight is compared against its shared base weight and three entries are
    yielded: ``lora_A_list``, ``lora_B_list`` and
    ``lora_magnitude_vector_list``, each indexed by the loop the original
    layer belongs to. Input and post-attention layernorm weights are remapped
    to ``*_list.{loop_idx}`` keys on the recursive layer. All other
    parameters are yielded unchanged.

    :param model_path: Location of the model shards, forwarded to
        ``iter_parameter_weights``
    :param avg_recursive_weights: Mapping that contains the shared base weight
        under ``model.layers.{idx}.{module}.weight_base``
    :param modules_to_recurse: Module name fragments that identify the shared
        (recursed) modules
    :param alpha: Forwarded to ``decompose_delta_weight``
    :param rank: Maximum rank of the low-rank factors
    :param device: Device on which the shard tensors are loaded
    :param recurse_layers: Number of distinct layers in the recursive model
    """
    layernorm_names = ("input_layernorm", "post_attention_layernorm")

    # iterate over all parameter weights in the model shards; the device is
    # forwarded so the caller's choice is honoured instead of the helper's default
    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
        # get the matching module name in modules_to_recurse for the current parameter key
        matched_module_name = next(
            (module for module in modules_to_recurse if module in key),
            None,
        )
        if matched_module_name is None:
            norm_name = next(
                (name for name in layernorm_names if name in key),
                None,
            )
            if norm_name is None:
                # not a recursed module and not a layernorm: pass through as-is
                yield key, weight
            else:
                # map to the matching layernorm list (input_layernorm_list or
                # post_attention_layernorm_list) in the recursive layers and
                # account for the layer_idx and loop_idx
                loop_idx, recurse_idx = divmod(layer_idx, recurse_layers)
                yield f"model.layers.{recurse_idx}.{norm_name}_list.{loop_idx}", weight
            continue

        # figure out the base weight layer for this key
        loop_idx, recurse_idx = divmod(layer_idx, recurse_layers)
        suffix = f"{recurse_idx}.{matched_module_name}"
        avg_weight = avg_recursive_weights[f"model.layers.{suffix}.weight_base"]

        lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
        yield f"model.layers.{suffix}.lora_A_list.{loop_idx}", lora_a
        yield f"model.layers.{suffix}.lora_B_list.{loop_idx}", lora_b
        yield f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}", lora_magnitude
|
||||
|
||||
def save_state_dict_to_safetensors(state_dict, save_directory):
    """
    Save a state dict to ``save_directory`` as safetensors shards of at most 1GB.

    Stale shard files from a previous save that are not part of the new
    layout are removed first. When the state dict is split into more than one
    shard, the shard index is written next to the shards.

    NOTE: entries are removed from ``state_dict`` as they are written, so the
    mapping passed in is empty afterwards.

    :param state_dict: Mapping of parameter name to tensor; consumed in place
    :param save_directory: Target directory, created if it does not exist
    """
    import json

    from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

    # os.listdir below fails on a missing directory, so create it up front.
    os.makedirs(save_directory, exist_ok=True)

    weights_name = SAFE_WEIGHTS_NAME

    filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
    state_dict_split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
    )

    # Build the index if sharded
    index = None
    if state_dict_split.is_sharded:
        index = {
            "metadata": state_dict_split.metadata,
            "weight_map": state_dict_split.tensor_to_filename,
        }

    # Clean the folder from a previous save. Both values are loop-invariant,
    # so they are computed once.
    weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
    # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
    shard_name_re = re.compile(r"(.*?)-\d{5}-of-\d{5}")

    for filename in os.listdir(save_directory):
        full_filename = os.path.join(save_directory, filename)
        filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")

        # Delete only shard files that are not going to be replaced.
        if (
            filename.startswith(weights_no_suffix)
            and os.path.isfile(full_filename)
            and filename not in state_dict_split.filename_to_tensors
            and shard_name_re.fullmatch(filename_no_suffix) is not None
        ):
            os.remove(full_filename)

    for shard_file, tensors in state_dict_split.filename_to_tensors.items():
        # pop() drops the entry from state_dict once it has been moved into the shard
        shard = {name: state_dict.pop(name).contiguous() for name in tensors}
        save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})

    # Without the index file a sharded checkpoint cannot be mapped back to
    # its shards when loading.
    if index is not None:
        index_path = os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME)
        with open(index_path, "w", encoding="utf-8") as f:
            f.write(json.dumps(index, indent=2, sort_keys=True) + "\n")
|
||||
|
||||
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32):
|
||||
modules_to_recurse = [
|
||||
"self_attn.q_proj",
|
||||
"self_attn.k_proj",
|
||||
@@ -88,12 +230,25 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12):
|
||||
f"divisible by the recurse layers ({recurse_layers})"
|
||||
)
|
||||
|
||||
model_path = Path(snapshot_download(model_name))
|
||||
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
||||
|
||||
# create a new state_dict to store the RRT model weights
|
||||
rrt_model_state_dict = {}
|
||||
|
||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="cpu", recurse_layers=recurse_layers):
|
||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers):
|
||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||
|
||||
# split_torch_state_dict_into_shards(...)
|
||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||
rrt_lora_state_dict = {}
|
||||
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers):
|
||||
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||
|
||||
# combine state dicts into a single state_dict
|
||||
rrt_model_state_dict.update(rrt_lora_state_dict)
|
||||
|
||||
# save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards
|
||||
save_state_dict_to_safetensors(rrt_model_state_dict, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import logging
|
||||
from typing import Tuple, Optional, Unpack, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig, Cache, logger, DynamicCache
|
||||
from transformers import LlamaConfig, Cache, DynamicCache
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
@@ -12,6 +13,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
|
||||
|
||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
"""
|
||||
@@ -25,7 +27,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_layers // config.recurse_layers
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
@@ -46,7 +48,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_layers // config.recurse_layers
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.config = config
|
||||
self.layer_idx = layer_idx
|
||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||
@@ -127,7 +129,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
recurse_loops = config.num_layers // config.recurse_layers
|
||||
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
|
||||
@@ -185,7 +187,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
|
||||
class RelaxedRecursiveLlamaModel(LlamaModel):
|
||||
def __init__(self, config):
|
||||
super(LlamaModel, self).__init__(config)
|
||||
self.recurse_loops = config.num_layers // config.recurse_layers
|
||||
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
|
||||
Reference in New Issue
Block a user