rescale the norm for lora
@@ -104,7 +104,7 @@ def low_rank_decomposition(
     U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
 
-    # Distribute S to both to improve numerical precision.
+    # Distribute S to both to improve numerical precision
     sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
     A = sqrt_S @ Vh[:max_rank, :]  # shape: [r, cols]
     B = U[:, :max_rank] @ sqrt_S  # shape: [rows, r]
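For context, a runnable sketch of the factorization this hunk touches. The body is reconstructed from the visible lines; the signature (`weight`, `max_rank`) and the sanity check below are assumptions:

import torch

def low_rank_decomposition(weight: torch.Tensor, max_rank: int):
    # SVD in float32 to avoid precision loss on bf16/fp16 weights.
    U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)

    # Distribute S to both factors to improve numerical precision
    sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
    A = sqrt_S @ Vh[:max_rank, :]  # shape: [r, cols]
    B = U[:, :max_rank] @ sqrt_S   # shape: [rows, r]
    return A, B

# B @ A is the best rank-r approximation of the weight (Eckart-Young).
W = torch.randn(64, 128)
A, B = low_rank_decomposition(W, max_rank=16)
assert (B @ A).shape == W.shape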
@@ -119,7 +119,7 @@ def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
     return weight_norm
 
 
-def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
+def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
     """
     Decompose the difference in directions (ΔV) via SVD,
     and return (magnitudes, L, R).
@@ -132,18 +132,21 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     base_weight = avg_weight.to(device)
     final_weight = layer_weight.to(device)
 
-    delta_first_pass = final_weight - base_weight
-
-    delta_for_svd = delta_first_pass
+    delta_for_svd = final_weight - base_weight
 
-    # 3. Low-rank factorization of the delta direction
+    # Low-rank factorization of the delta direction
     lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
 
-    lora_weight = lora_B @ lora_A
-    weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
-    # let's rescale the lora weight to have the same magnitude as the base weight
-
-    return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
+    if use_dora:
+        lora_weight = lora_B @ lora_A
+        weight_norm = get_weight_norm(
+            base_weight.to(lora_A.device), lora_weight, scaling
+        )
+        return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
+
+    return lora_A.cpu(), lora_B.cpu(), None
 
 
 def iter_dora_parameter_weights(
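The `use_dora` branch follows DoRA's magnitude/direction split. The body of `get_weight_norm` is not shown in this diff; a plausible sketch, assuming it follows the DoRA convention of taking per-output-channel norms of the combined weight (the formula here is an assumption, only the signature comes from the hunk above):

import torch

def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
    # Hypothetical body: DoRA's magnitude vector is the per-row
    # (output-channel) L2 norm of the combined weight W + s * (B @ A).
    combined = weight + scaling * lora_weight    # shape: [rows, cols]
    weight_norm = torch.linalg.norm(combined, dim=1)  # shape: [rows]
    return weight_norm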
@@ -154,9 +157,8 @@ def iter_dora_parameter_weights(
     rank,
     device="mps",
     recurse_layers=12,
+    use_dora=True,
 ):
-    rrt_avg_model_state_dict = {}
-
     # iterate over all parameter weights in the model shards
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
@@ -194,11 +196,16 @@ def iter_dora_parameter_weights(
                 f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
             )
             lora_a, lora_b, lora_magnitude = decompose_delta_weight(
-                weight, avg_weight, alpha, rank
+                weight,
+                avg_weight,
+                alpha,
+                rank,
+                use_dora=use_dora,
             )
             yield lora_a_key, lora_a
             yield lora_b_key, lora_b
-            yield lora_magnitude_key, lora_magnitude
+            if use_dora:
+                yield lora_magnitude_key, lora_magnitude
 
 
 def save_state_dict_to_safetensors(state_dict, save_directory):
@@ -315,13 +322,13 @@ def convert_llama_to_rrt(
     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}
 
-    logger.info(f"Calculating average recursive weights...")
+    logger.info("Calculating average recursive weights...")
     for key, weight in iter_recursive_parameter_weights(
         model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
     ):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
 
-    logger.info(f"Calculating decomposed lora diff...")
+    logger.info("Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
     for key, weight in iter_dora_parameter_weights(
@@ -332,6 +339,7 @@ def convert_llama_to_rrt(
         rank=rank,
         device=device,
         recurse_layers=recurse_layers,
+        use_dora=use_dora,
     ):
         rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
@@ -344,8 +352,9 @@ def convert_llama_to_rrt(
 
 
 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
+    # meta-llama/Llama-3.2-3B has 28 hidden layers
     convert_llama_to_rrt(
-        "meta-llama/Llama-3.2-1B",
+        "meta-llama/Llama-3.2-3B",
         "/tmp/rrt_model",
         recurse_layers=4,
         rank=256,
@@ -71,6 +71,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         :param loop_idx:
         :return:
         """
+        eps = 1e-6
         w_base = self.weight_base
         w_base = w_base.to(x.dtype)
 
@@ -78,8 +79,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
 
         base_out: torch.Tensor = F.linear(x, w_base, self.bias)
-
-        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
+        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
 
         if self.use_dora:
             x_eye: torch.Tensor = torch.eye(
@@ -98,8 +98,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
                 0
             )  # shape [1, out_features]
 
-            result_dora = (
-                scale_factor - 1
-            ) * base_out + scale_factor * lora_out * self.scaling
+            result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
             return result_dora
 
-        return base_out + lora_out * self.scaling
+        # scale the lora norm to prevent gradient explosion
+        orig_norm = torch.linalg.norm(w_base)
+        update_norm = torch.linalg.norm(lora_out)
+        scale = orig_norm / (update_norm + eps)
+
+        return base_out + lora_out * scale
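To illustrate the rescaling the commit title refers to: in the non-DoRA branch, the LoRA contribution is shrunk (or grown) so its norm matches the frozen base weight's norm, bounding the update no matter how large B @ A is. A minimal standalone sketch of that branch; the tensor layout and `F.linear` usage follow the diff, while the concrete shapes and values are illustrative:

import torch
import torch.nn.functional as F

eps = 1e-6
x = torch.randn(4, 128)             # [batch, in_features]
w_base = torch.randn(64, 128)       # frozen base weight
lora_A = torch.randn(16, 128) * 5   # deliberately oversized update
lora_B = torch.randn(64, 16)
scaling = 2.0                       # alpha / rank in LoRA terms

base_out = F.linear(x, w_base)
lora_out = F.linear(F.linear(x, lora_A), lora_B) * scaling

# Rescale so the LoRA output's norm matches the base weight's norm.
scale = torch.linalg.norm(w_base) / (torch.linalg.norm(lora_out) + eps)
out = base_out + lora_out * scale

print(torch.linalg.norm(lora_out * scale))  # ~= ||w_base||, up to eps

Note that `scale` compares a weight norm against an activation norm, so the effective update size varies with the input batch; that trade-off is inherent to the formula in the diff.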