rescale the norm for lora

Wing Lian
2025-01-22 10:02:29 -05:00
parent fa5efbf235
commit 0af78a9882
2 changed files with 35 additions and 22 deletions

View File

@@ -104,7 +104,7 @@ def low_rank_decomposition(
     U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
-    # Distribute S to both to improve numerical precision.
+    # Distribute S to both to improve numerical precision
     sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
     A = sqrt_S @ Vh[:max_rank, :]  # shape: [r, cols]
     B = U[:, :max_rank] @ sqrt_S  # shape: [rows, r]
@@ -119,7 +119,7 @@ def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
     return weight_norm


-def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
+def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
     """
     Decompose the difference in directions (ΔV) via SVD,
     and return (magnitudes, L, R).
@@ -132,18 +132,21 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     base_weight = avg_weight.to(device)
     final_weight = layer_weight.to(device)
-    delta_first_pass = final_weight - base_weight
+    delta_for_svd = final_weight - base_weight
-    delta_for_svd = delta_first_pass
-    # 3. Low-rank factorization of the delta direction
+    # Low-rank factorization of the delta direction
     lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
-    lora_weight = lora_B @ lora_A
+    if use_dora:
+        lora_weight = lora_B @ lora_A
+        weight_norm = get_weight_norm(
+            base_weight.to(lora_A.device), lora_weight, scaling
+        )
+        return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
-    weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
-    # let's rescale the lora weight to have the same magnitude as the base weight
-    return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
+    return lora_A.cpu(), lora_B.cpu(), None


 def iter_dora_parameter_weights(
@@ -154,9 +157,8 @@ def iter_dora_parameter_weights(
     rank,
     device="mps",
     recurse_layers=12,
+    use_dora=True,
 ):
-    rrt_avg_model_state_dict = {}
     # iterate over all parameter weights in the model shards
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
@@ -194,11 +196,16 @@ def iter_dora_parameter_weights(
                 f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
             )
             lora_a, lora_b, lora_magnitude = decompose_delta_weight(
-                weight, avg_weight, alpha, rank
+                weight,
+                avg_weight,
+                alpha,
+                rank,
+                use_dora=use_dora,
             )
             yield lora_a_key, lora_a
             yield lora_b_key, lora_b
-            yield lora_magnitude_key, lora_magnitude
+            if use_dora:
+                yield lora_magnitude_key, lora_magnitude


 def save_state_dict_to_safetensors(state_dict, save_directory):
@@ -315,13 +322,13 @@ def convert_llama_to_rrt(
     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}
-    logger.info(f"Calculating average recursive weights...")
+    logger.info("Calculating average recursive weights...")
     for key, weight in iter_recursive_parameter_weights(
         model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
     ):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()

-    logger.info(f"Calculating decomposed lora diff...")
+    logger.info("Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
     for key, weight in iter_dora_parameter_weights(
@@ -332,6 +339,7 @@ def convert_llama_to_rrt(
         rank=rank,
         device=device,
         recurse_layers=recurse_layers,
+        use_dora=use_dora,
     ):
         rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
@@ -344,8 +352,9 @@ def convert_llama_to_rrt(

 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
+    # meta-llama/Llama-3.2-3B has 28 hidden layers
     convert_llama_to_rrt(
-        "meta-llama/Llama-3.2-1B",
+        "meta-llama/Llama-3.2-3B",
         "/tmp/rrt_model",
         recurse_layers=4,
         rank=256,

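Note on the conversion change above: with use_dora=False, decompose_delta_weight now skips the DoRA magnitude entirely and yields None in its place, and iter_dora_parameter_weights only emits the lora_magnitude key when DoRA is enabled. The sketch below is a minimal, self-contained illustration of that decompose path (SVD-factorize the delta against the shared weight, optionally compute a per-output-channel norm); the shapes, rank, and scaling value are illustrative assumptions, not values from this repo.

# Illustrative sketch only; shapes, rank, and scaling are assumptions.
import torch

def decompose_delta_sketch(layer_weight, avg_weight, rank, scaling=1.0, use_dora=True):
    delta = layer_weight.float() - avg_weight.float()
    # SVD-based low-rank factorization of the delta, splitting sqrt(S) onto both factors
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    sqrt_S = torch.sqrt(torch.diag(S[:rank]))
    lora_A = sqrt_S @ Vh[:rank, :]   # [rank, in_features]
    lora_B = U[:, :rank] @ sqrt_S    # [out_features, rank]
    if use_dora:
        # per-output-channel norm of the merged weight, used as the DoRA magnitude
        lora_weight = lora_B @ lora_A
        weight_norm = torch.linalg.norm(avg_weight + scaling * lora_weight, dim=1)
        return lora_A, lora_B, weight_norm
    return lora_A, lora_B, None

layer_w, avg_w = torch.randn(64, 32), torch.randn(64, 32)
A, B, mag = decompose_delta_sketch(layer_w, avg_w, rank=8, use_dora=False)
print(A.shape, B.shape, mag)  # torch.Size([8, 32]) torch.Size([64, 8]) None
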
View File

@@ -71,6 +71,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         :param loop_idx:
         :return:
         """
+        eps = 1e-6
         w_base = self.weight_base
         w_base = w_base.to(x.dtype)
@@ -78,8 +79,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
         base_out: torch.Tensor = F.linear(x, w_base, self.bias)
-        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
+        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
         if self.use_dora:
             x_eye: torch.Tensor = torch.eye(
@@ -98,8 +98,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
                 0
             )  # shape [1, out_features]
-            result_dora = (
-                scale_factor - 1
-            ) * base_out + scale_factor * lora_out * self.scaling
+            result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
             return result_dora
-        return base_out + lora_out * self.scaling
+
+        # scale the lora norm to prevent gradient explosion
+        orig_norm = torch.linalg.norm(w_base)
+        update_norm = torch.linalg.norm(lora_out)
+        scale = orig_norm / (update_norm + eps)
+        return base_out + lora_out * scale
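
For the forward-pass change in the second file: when DoRA is disabled, the LoRA contribution is no longer added at its raw scaled magnitude; it is renormalized so its norm tracks the base weight's norm (the "scale the lora norm to prevent gradient explosion" path). A rough standalone sketch of that path, with assumed layer sizes, rank, and scaling:

# Rough sketch of the rescaled non-DoRA path; sizes, rank, and scaling are assumptions.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features, rank = 32, 64, 8
eps = 1e-6

w_base = torch.randn(out_features, in_features)
lora_A = torch.randn(rank, in_features) * 0.01
lora_B = torch.randn(out_features, rank) * 0.01
scaling = 2.0  # alpha / rank in typical LoRA setups

x = torch.randn(4, in_features)
base_out = F.linear(x, w_base)
lora_out = F.linear(F.linear(x, lora_A), lora_B) * scaling

# rescale the LoRA contribution toward the magnitude of the base weight
orig_norm = torch.linalg.norm(w_base)
update_norm = torch.linalg.norm(lora_out)
scale = orig_norm / (update_norm + eps)
out = base_out + lora_out * scale
print(out.shape)  # torch.Size([4, 64])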