wip conversion cli
This commit is contained in:
@@ -1,11 +1,15 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import re
|
from typing import Tuple
|
||||||
|
|
||||||
import safetensors
|
import safetensors
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download, split_torch_state_dict_into_shards
|
||||||
|
from safetensors.torch import save_file
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
def extract_layer_number(key):
|
def extract_layer_number(key):
|
||||||
@@ -14,7 +18,7 @@ def extract_layer_number(key):
|
|||||||
return int(match.group(1)) if match else None
|
return int(match.group(1)) if match else None
|
||||||
|
|
||||||
|
|
||||||
def iter_parameter_weights(model_path, device="cpu"):
|
def iter_parameter_weights(model_path, device="mps"):
|
||||||
"""
|
"""
|
||||||
iterator over parameter weights in the model shards
|
iterator over parameter weights in the model shards
|
||||||
|
|
||||||
@@ -33,7 +37,7 @@ def iter_parameter_weights(model_path, device="cpu"):
|
|||||||
weight = f.get_tensor(key)
|
weight = f.get_tensor(key)
|
||||||
yield key, weight, layer_idx
|
yield key, weight, layer_idx
|
||||||
|
|
||||||
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="cpu", recurse_layers=12):
|
def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="mps", recurse_layers=12):
|
||||||
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
# setup placeholder state_dict for recursive weights, need to keep in float32 precision
|
||||||
# to avoid precision loss when averaging weights across layers
|
# to avoid precision loss when averaging weights across layers
|
||||||
rrt_avg_model_state_dict = {}
|
rrt_avg_model_state_dict = {}
|
||||||
@@ -46,30 +50,168 @@ def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str],
|
|||||||
None
|
None
|
||||||
)
|
)
|
||||||
if matched_module_name is None:
|
if matched_module_name is None:
|
||||||
if "input_layernorm" in key:
|
continue
|
||||||
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
|
||||||
yield
|
|
||||||
else:
|
|
||||||
yield key, weight
|
|
||||||
|
|
||||||
recurse_idx = layer_idx % recurse_layers
|
recurse_idx = layer_idx % recurse_layers
|
||||||
suffix = f"{recurse_idx}.{matched_module_name}"
|
suffix = f"{recurse_idx}.{matched_module_name}"
|
||||||
prefix = f"model.layers.{suffix}."
|
|
||||||
if rrt_avg_model_state_dict.get(suffix) is None:
|
if rrt_avg_model_state_dict.get(suffix) is None:
|
||||||
# setup as storage for suffix with torch.stack
|
# setup as storage for suffix with torch.stack
|
||||||
rrt_avg_model_state_dict[suffix] = torch.stack([weight.to(torch.float32).detach().cpu()])
|
rrt_avg_model_state_dict[suffix] = [weight.to(torch.float32).detach().cpu()]
|
||||||
else:
|
else:
|
||||||
rrt_avg_model_state_dict[suffix] = torch.cat([rrt_avg_model_state_dict[suffix], weight.to(torch.float32).detach().cpu()])
|
rrt_avg_model_state_dict[suffix].append(weight.to(torch.float32).detach().cpu())
|
||||||
|
|
||||||
for module_name in modules_to_recurse:
|
for module_name in modules_to_recurse:
|
||||||
for recurse_idx in range(recurse_layers):
|
for recurse_idx in range(recurse_layers):
|
||||||
suffix = f"{recurse_idx}.{module_name}"
|
suffix = f"{recurse_idx}.{module_name}"
|
||||||
prefix = f"model.layers.{suffix}."
|
prefix = f"model.layers.{suffix}"
|
||||||
avg_weight = rrt_avg_model_state_dict[suffix].mean(dim=0)
|
avg_weight = torch.stack(rrt_avg_model_state_dict[suffix]).mean(dim=0)
|
||||||
yield f"{prefix}.weight", avg_weight
|
yield f"{prefix}.weight_base", avg_weight
|
||||||
|
|
||||||
|
# compute the decomposed lora diff from the weight base to the actual weight for each module
|
||||||
|
|
||||||
|
def low_rank_decomposition(
|
||||||
|
weight: torch.Tensor, max_rank: int
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Decompose a 2D matrix into low-rank matrices L and R using SVD.
|
||||||
|
|
||||||
|
:param weight: The matrix to decompose, of shape (H, W)
|
||||||
|
:param max_rank: The maximum rank of the decomposition
|
||||||
|
:return: A tuple of tensors (L, R)
|
||||||
|
"""
|
||||||
|
assert (
|
||||||
|
weight.dim() == 2
|
||||||
|
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
|
||||||
|
assert (
|
||||||
|
max_rank >= 1
|
||||||
|
), f"Maximum rank must be a positive integer, but input max_rank={max_rank}."
|
||||||
|
|
||||||
|
dtype = weight.dtype
|
||||||
|
|
||||||
|
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
|
||||||
|
|
||||||
|
final_rank = min(min(weight.shape), max_rank)
|
||||||
|
|
||||||
|
# Distribute S to both to improve numerical precision.
|
||||||
|
sqrt_S = torch.sqrt(torch.diag(S[:final_rank]))
|
||||||
|
L = sqrt_S @ Vh[:final_rank, :]
|
||||||
|
R = U[:, :final_rank] @ sqrt_S
|
||||||
|
|
||||||
|
return L.to(dtype), R.to(dtype)
|
||||||
|
|
||||||
|
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
|
||||||
|
"""
|
||||||
|
Decompose the difference in directions (ΔV) via SVD,
|
||||||
|
and return (magnitudes, L, R).
|
||||||
|
"""
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "mps"
|
||||||
|
|
||||||
|
base_weight = avg_weight.to(device)
|
||||||
|
finetuned_weight = layer_weight.to(device)
|
||||||
|
|
||||||
|
# 1. Compute column norms and directions
|
||||||
|
# (shape: base_norms, finetuned_norms => (k,))
|
||||||
|
base_norms = torch.norm(base_weight, dim=0) + 1e-9
|
||||||
|
finetuned_norms = torch.norm(finetuned_weight, dim=0) + 1e-9
|
||||||
|
|
||||||
|
# shape (d, k)
|
||||||
|
base_dir = base_weight / base_norms
|
||||||
|
finetuned_dir = finetuned_weight / finetuned_norms
|
||||||
|
|
||||||
|
# 2. Delta direction
|
||||||
|
delta_dir = finetuned_dir - base_dir
|
||||||
|
|
||||||
|
# 3. Low-rank factorization of the delta direction
|
||||||
|
A, B = low_rank_decomposition(delta_dir, rank)
|
||||||
|
# The final magnitudes are just finetuned_norms
|
||||||
|
return A.cpu(), B.cpu(), finetuned_norms.cpu()
|
||||||
|
|
||||||
|
|
||||||
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12):
|
def iter_dora_parameter_weights(model_path, avg_recursive_weights, modules_to_recurse: list[str], alpha, rank, device="mps", recurse_layers=12):
|
||||||
|
rrt_avg_model_state_dict = {}
|
||||||
|
|
||||||
|
# iterate over all parameter weights in the model shards
|
||||||
|
for key, weight, layer_idx in iter_parameter_weights(model_path):
|
||||||
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
|
matched_module_name = next(
|
||||||
|
(module for module in modules_to_recurse if module in key),
|
||||||
|
None
|
||||||
|
)
|
||||||
|
if matched_module_name is None:
|
||||||
|
if "input_layernorm" in key:
|
||||||
|
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||||
|
loop_idx = layer_idx // recurse_layers
|
||||||
|
layer_idx = layer_idx % recurse_layers
|
||||||
|
layernorm_key = f"model.layers.{layer_idx}.input_layernorm_list.{loop_idx}"
|
||||||
|
yield layernorm_key, weight
|
||||||
|
elif "post_attention_layernorm" in key:
|
||||||
|
# map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx
|
||||||
|
loop_idx = layer_idx // recurse_layers
|
||||||
|
layer_idx = layer_idx % recurse_layers
|
||||||
|
layernorm_key = f"model.layers.{layer_idx}.post_attention_layernorm_list.{loop_idx}"
|
||||||
|
yield layernorm_key, weight
|
||||||
|
else:
|
||||||
|
yield key, weight
|
||||||
|
continue
|
||||||
|
|
||||||
|
# figure out the base weight layer for this key
|
||||||
|
loop_idx = layer_idx // recurse_layers
|
||||||
|
layer_idx = layer_idx % recurse_layers
|
||||||
|
suffix = f"{layer_idx}.{matched_module_name}"
|
||||||
|
prefix = f"model.layers.{suffix}.weight_base"
|
||||||
|
avg_weight = avg_recursive_weights[prefix]
|
||||||
|
lora_a_key = f"model.layers.{suffix}.lora_A_list.{loop_idx}"
|
||||||
|
lora_b_key = f"model.layers.{suffix}.lora_B_list.{loop_idx}"
|
||||||
|
lora_magnitude_key = f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
||||||
|
lora_a, lora_b, lora_magnitude = decompose_delta_weight(weight, avg_weight, alpha, rank)
|
||||||
|
yield lora_a_key, lora_a
|
||||||
|
yield lora_b_key, lora_b
|
||||||
|
yield lora_magnitude_key, lora_magnitude
|
||||||
|
|
||||||
|
def save_state_dict_to_safetensors(state_dict, save_directory):
|
||||||
|
weights_name = SAFE_WEIGHTS_NAME
|
||||||
|
|
||||||
|
filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
|
||||||
|
state_dict_split = split_torch_state_dict_into_shards(
|
||||||
|
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
|
||||||
|
)
|
||||||
|
# Save index if sharded
|
||||||
|
index = None
|
||||||
|
if state_dict_split.is_sharded:
|
||||||
|
index = {
|
||||||
|
"metadata": state_dict_split.metadata,
|
||||||
|
"weight_map": state_dict_split.tensor_to_filename,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Clean the folder from a previous save
|
||||||
|
for filename in os.listdir(save_directory):
|
||||||
|
full_filename = os.path.join(save_directory, filename)
|
||||||
|
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||||
|
# in distributed settings to avoid race conditions.
|
||||||
|
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
|
||||||
|
|
||||||
|
# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
|
||||||
|
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
|
||||||
|
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
filename.startswith(weights_no_suffix)
|
||||||
|
and os.path.isfile(full_filename)
|
||||||
|
and filename not in state_dict_split.filename_to_tensors.keys()
|
||||||
|
and reg.fullmatch(filename_no_suffix) is not None
|
||||||
|
):
|
||||||
|
os.remove(full_filename)
|
||||||
|
|
||||||
|
filename_to_tensors = state_dict_split.filename_to_tensors.items()
|
||||||
|
for shard_file, tensors in filename_to_tensors:
|
||||||
|
shard = {}
|
||||||
|
for tensor in tensors:
|
||||||
|
shard[tensor] = state_dict[tensor].contiguous()
|
||||||
|
del state_dict[tensor]
|
||||||
|
|
||||||
|
save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
|
||||||
|
|
||||||
|
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=32):
|
||||||
modules_to_recurse = [
|
modules_to_recurse = [
|
||||||
"self_attn.q_proj",
|
"self_attn.q_proj",
|
||||||
"self_attn.k_proj",
|
"self_attn.k_proj",
|
||||||
@@ -88,12 +230,25 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12):
|
|||||||
f"divisible by the recurse layers ({recurse_layers})"
|
f"divisible by the recurse layers ({recurse_layers})"
|
||||||
)
|
)
|
||||||
|
|
||||||
model_path = Path(snapshot_download(model_name))
|
model_path = Path(snapshot_download(model_name, ignore_patterns="*.pth"))
|
||||||
|
|
||||||
# create a new state_dict to store the RRT model weights
|
# create a new state_dict to store the RRT model weights
|
||||||
rrt_model_state_dict = {}
|
rrt_model_state_dict = {}
|
||||||
|
|
||||||
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="cpu", recurse_layers=recurse_layers):
|
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="mps", recurse_layers=recurse_layers):
|
||||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
# split_torch_state_dict_into_shards(...)
|
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||||
|
rrt_lora_state_dict = {}
|
||||||
|
for key, weight in iter_dora_parameter_weights(model_path, rrt_model_state_dict, modules_to_recurse, alpha=32, rank=rank, device="mps", recurse_layers=recurse_layers):
|
||||||
|
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
|
# combine state dicts into a single state_dict
|
||||||
|
rrt_model_state_dict.update(rrt_lora_state_dict)
|
||||||
|
|
||||||
|
# save state dict as sharded safetensors to disk using split_torch_state_dict_into_shards
|
||||||
|
save_state_dict_to_safetensors(rrt_model_state_dict, output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
|
import logging
|
||||||
from typing import Tuple, Optional, Unpack, Callable, Union
|
from typing import Tuple, Optional, Unpack, Callable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import LlamaConfig, Cache, logger, DynamicCache
|
from transformers import LlamaConfig, Cache, DynamicCache
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
@@ -12,6 +13,7 @@ from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager
|
|||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
||||||
"""
|
"""
|
||||||
@@ -25,7 +27,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
|||||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||||
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
def __init__(self, config: RelaxedRecursiveLlamaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
recurse_loops = config.num_layers // config.recurse_layers
|
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||||
self.config = config
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
@@ -46,7 +48,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
|
def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
recurse_loops = config.num_layers // config.recurse_layers
|
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||||
self.config = config
|
self.config = config
|
||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
|
||||||
@@ -127,7 +129,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
|
|||||||
|
|
||||||
def __init__(self, config: LlamaConfig, layer_idx: int):
|
def __init__(self, config: LlamaConfig, layer_idx: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
recurse_loops = config.num_layers // config.recurse_layers
|
recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
|
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
|
||||||
@@ -185,7 +187,7 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
|
|||||||
class RelaxedRecursiveLlamaModel(LlamaModel):
|
class RelaxedRecursiveLlamaModel(LlamaModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(LlamaModel, self).__init__(config)
|
super(LlamaModel, self).__init__(config)
|
||||||
self.recurse_loops = config.num_layers // config.recurse_layers
|
self.recurse_loops = config.num_hidden_layers // config.recurse_layers
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user