rescale the norm for lora
This commit is contained in:
@@ -104,7 +104,7 @@ def low_rank_decomposition(
|
|||||||
|
|
||||||
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
|
U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
|
||||||
|
|
||||||
# Distribute S to both to improve numerical precision.
|
# Distribute S to both to improve numerical precision
|
||||||
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
|
sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
|
||||||
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
|
A = sqrt_S @ Vh[:max_rank, :] # shape: [r, cols]
|
||||||
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
|
B = U[:, :max_rank] @ sqrt_S # shape: [rows, r]
|
||||||
@@ -119,7 +119,7 @@ def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
|
|||||||
return weight_norm
|
return weight_norm
|
||||||
|
|
||||||
|
|
||||||
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
|
def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
|
||||||
"""
|
"""
|
||||||
Decompose the difference in directions (ΔV) via SVD,
|
Decompose the difference in directions (ΔV) via SVD,
|
||||||
and return (magnitudes, L, R).
|
and return (magnitudes, L, R).
|
||||||
@@ -132,18 +132,21 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
|
|||||||
base_weight = avg_weight.to(device)
|
base_weight = avg_weight.to(device)
|
||||||
final_weight = layer_weight.to(device)
|
final_weight = layer_weight.to(device)
|
||||||
|
|
||||||
delta_first_pass = final_weight - base_weight
|
delta_for_svd = final_weight - base_weight
|
||||||
|
|
||||||
delta_for_svd = delta_first_pass
|
# Low-rank factorization of the delta direction
|
||||||
|
|
||||||
# 3. Low-rank factorization of the delta direction
|
|
||||||
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
|
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
|
||||||
|
|
||||||
lora_weight = lora_B @ lora_A
|
if use_dora:
|
||||||
|
lora_weight = lora_B @ lora_A
|
||||||
|
weight_norm = get_weight_norm(
|
||||||
|
base_weight.to(lora_A.device), lora_weight, scaling
|
||||||
|
)
|
||||||
|
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
|
||||||
|
|
||||||
weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
|
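# plain LoRA path (use_dora=False): no magnitude vector is computed here, so None is returned in its place; the LoRA update is rescaled against the base weight norm later, in the layer's forward pass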
|
||||||
|
|
||||||
return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
|
return lora_A.cpu(), lora_B.cpu(), None
|
||||||
|
|
||||||
|
|
||||||
def iter_dora_parameter_weights(
|
def iter_dora_parameter_weights(
|
||||||
@@ -154,9 +157,8 @@ def iter_dora_parameter_weights(
|
|||||||
rank,
|
rank,
|
||||||
device="mps",
|
device="mps",
|
||||||
recurse_layers=12,
|
recurse_layers=12,
|
||||||
|
use_dora=True,
|
||||||
):
|
):
|
||||||
rrt_avg_model_state_dict = {}
|
|
||||||
|
|
||||||
# iterate over all parameter weights in the model shards
|
# iterate over all parameter weights in the model shards
|
||||||
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
|
||||||
# get the matching module name in modules_to_recurse for the current parameter key
|
# get the matching module name in modules_to_recurse for the current parameter key
|
||||||
@@ -194,11 +196,16 @@ def iter_dora_parameter_weights(
|
|||||||
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
|
||||||
)
|
)
|
||||||
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
|
lora_a, lora_b, lora_magnitude = decompose_delta_weight(
|
||||||
weight, avg_weight, alpha, rank
|
weight,
|
||||||
|
avg_weight,
|
||||||
|
alpha,
|
||||||
|
rank,
|
||||||
|
use_dora=use_dora,
|
||||||
)
|
)
|
||||||
yield lora_a_key, lora_a
|
yield lora_a_key, lora_a
|
||||||
yield lora_b_key, lora_b
|
yield lora_b_key, lora_b
|
||||||
yield lora_magnitude_key, lora_magnitude
|
if use_dora:
|
||||||
|
yield lora_magnitude_key, lora_magnitude
|
||||||
|
|
||||||
|
|
||||||
def save_state_dict_to_safetensors(state_dict, save_directory):
|
def save_state_dict_to_safetensors(state_dict, save_directory):
|
||||||
@@ -315,13 +322,13 @@ def convert_llama_to_rrt(
|
|||||||
# create a new state_dict to store the RRT model weights
|
# create a new state_dict to store the RRT model weights
|
||||||
rrt_model_state_dict = {}
|
rrt_model_state_dict = {}
|
||||||
|
|
||||||
logger.info(f"Calculating average recursive weights...")
|
logger.info("Calculating average recursive weights...")
|
||||||
for key, weight in iter_recursive_parameter_weights(
|
for key, weight in iter_recursive_parameter_weights(
|
||||||
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
|
model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
|
||||||
):
|
):
|
||||||
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
logger.info(f"Calculating decomposed lora diff...")
|
logger.info("Calculating decomposed lora diff...")
|
||||||
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
# now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
|
||||||
rrt_lora_state_dict = {}
|
rrt_lora_state_dict = {}
|
||||||
for key, weight in iter_dora_parameter_weights(
|
for key, weight in iter_dora_parameter_weights(
|
||||||
@@ -332,6 +339,7 @@ def convert_llama_to_rrt(
|
|||||||
rank=rank,
|
rank=rank,
|
||||||
device=device,
|
device=device,
|
||||||
recurse_layers=recurse_layers,
|
recurse_layers=recurse_layers,
|
||||||
|
use_dora=use_dora,
|
||||||
):
|
):
|
||||||
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
|
||||||
|
|
||||||
@@ -344,8 +352,9 @@ def convert_llama_to_rrt(
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||||
|
# meta-llama/Llama-3.2-3B has 28 hidden layers
|
||||||
convert_llama_to_rrt(
|
convert_llama_to_rrt(
|
||||||
"meta-llama/Llama-3.2-1B",
|
"meta-llama/Llama-3.2-3B",
|
||||||
"/tmp/rrt_model",
|
"/tmp/rrt_model",
|
||||||
recurse_layers=4,
|
recurse_layers=4,
|
||||||
rank=256,
|
rank=256,
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
:param loop_idx:
|
:param loop_idx:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
eps = 1e-6
|
||||||
w_base = self.weight_base
|
w_base = self.weight_base
|
||||||
w_base = w_base.to(x.dtype)
|
w_base = w_base.to(x.dtype)
|
||||||
|
|
||||||
@@ -78,8 +79,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
|
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
|
||||||
|
|
||||||
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
||||||
|
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
|
||||||
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
|
|
||||||
|
|
||||||
if self.use_dora:
|
if self.use_dora:
|
||||||
x_eye: torch.Tensor = torch.eye(
|
x_eye: torch.Tensor = torch.eye(
|
||||||
@@ -98,8 +98,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
0
|
0
|
||||||
) # shape [1, out_features]
|
) # shape [1, out_features]
|
||||||
|
|
||||||
result_dora = (
|
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
||||||
scale_factor - 1
|
|
||||||
) * base_out + scale_factor * lora_out * self.scaling
|
|
||||||
return result_dora
|
return result_dora
|
||||||
return base_out + lora_out * self.scaling
|
|
||||||
|
# scale the lora norm to prevent gradient explosion
|
||||||
|
orig_norm = torch.linalg.norm(w_base)
|
||||||
|
update_norm = torch.linalg.norm(lora_out)
|
||||||
|
scale = orig_norm / (update_norm + eps)
|
||||||
|
|
||||||
|
return base_out + lora_out * scale
|
||||||
|
|||||||
Reference in New Issue
Block a user