diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index 47cc66581..5ad7427f5 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -104,7 +104,7 @@ def low_rank_decomposition(
     U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
 
-    # Distribute S to both to improve numerical precision.
+    # Distribute S to both to improve numerical precision
     sqrt_S = torch.sqrt(torch.diag(S[:max_rank]))
     A = sqrt_S @ Vh[:max_rank, :]  # shape: [r, cols]
     B = U[:, :max_rank] @ sqrt_S  # shape: [rows, r]
 
@@ -119,7 +119,7 @@ def get_weight_norm(weight, lora_weight, scaling) -> torch.Tensor:
     return weight_norm
 
 
-def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
+def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True):
     """
     Decompose the difference in directions (ΔV) via SVD, and return
     (magnitudes, L, R).
@@ -132,18 +132,21 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank):
     base_weight = avg_weight.to(device)
     final_weight = layer_weight.to(device)
 
-    delta_first_pass = final_weight - base_weight
+    delta_for_svd = final_weight - base_weight
 
-    delta_for_svd = delta_first_pass
-
-    # 3. Low-rank factorization of the delta direction
+    # Low-rank factorization of the delta direction
    lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
 
-    lora_weight = lora_B @ lora_A
+    if use_dora:
+        lora_weight = lora_B @ lora_A
+        weight_norm = get_weight_norm(
+            base_weight.to(lora_A.device), lora_weight, scaling
+        )
+        return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
 
-    weight_norm = get_weight_norm(base_weight.to(lora_A.device), lora_weight, scaling)
+    # let's rescale the lora weight to have the same magnitude as the base weight
 
-    return lora_A.cpu(), lora_B.cpu(), weight_norm.cpu()
+    return lora_A.cpu(), lora_B.cpu(), None
 
 
 def iter_dora_parameter_weights(
@@ -154,9 +157,8 @@ def iter_dora_parameter_weights(
     rank,
     device="mps",
     recurse_layers=12,
+    use_dora=True,
 ):
-    rrt_avg_model_state_dict = {}
-
     # iterate over all parameter weights in the model shards
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
         # get the matching module name in modules_to_recurse for the current parameter key
@@ -194,11 +196,16 @@ def iter_dora_parameter_weights(
                 f"model.layers.{suffix}.lora_magnitude_vector_list.{loop_idx}"
             )
             lora_a, lora_b, lora_magnitude = decompose_delta_weight(
-                weight, avg_weight, alpha, rank
+                weight,
+                avg_weight,
+                alpha,
+                rank,
+                use_dora=use_dora,
             )
             yield lora_a_key, lora_a
             yield lora_b_key, lora_b
-            yield lora_magnitude_key, lora_magnitude
+            if use_dora:
+                yield lora_magnitude_key, lora_magnitude
 
 
 def save_state_dict_to_safetensors(state_dict, save_directory):
@@ -315,13 +322,13 @@ def convert_llama_to_rrt(
     # create a new state_dict to store the RRT model weights
     rrt_model_state_dict = {}
 
-    logger.info(f"Calculating average recursive weights...")
+    logger.info("Calculating average recursive weights...")
     for key, weight in iter_recursive_parameter_weights(
         model_path, modules_to_recurse, device=device, recurse_layers=recurse_layers
     ):
         rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
 
-    logger.info(f"Calculating decomposed lora diff...")
+    logger.info("Calculating decomposed lora diff...")
     # now that we have the average weights, we need to loop over the shards again to calculate the decomposed lora diff
     rrt_lora_state_dict = {}
     for key, weight in iter_dora_parameter_weights(
@@ -332,6 +339,7 @@ def convert_llama_to_rrt(
         rank=rank,
         device=device,
         recurse_layers=recurse_layers,
+        use_dora=use_dora,
     ):
         rrt_lora_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
 
@@ -344,8 +352,9 @@ def convert_llama_to_rrt(
 
 if __name__ == "__main__":
     # meta-llama/Llama-3.2-1B has 16 hidden layers
+    # meta-llama/Llama-3.2-3B has 28 hidden layers
     convert_llama_to_rrt(
-        "meta-llama/Llama-3.2-1B",
+        "meta-llama/Llama-3.2-3B",
         "/tmp/rrt_model",
         recurse_layers=4,
         rank=256,
diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py
index 38516801e..f4d4b95bb 100644
--- a/src/axolotl/integrations/rrt/modeling/linear.py
+++ b/src/axolotl/integrations/rrt/modeling/linear.py
@@ -71,6 +71,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         :param loop_idx:
         :return:
         """
+        eps = 1e-6
         w_base = self.weight_base
         w_base = w_base.to(x.dtype)
 
@@ -78,8 +79,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
 
         base_out: torch.Tensor = F.linear(x, w_base, self.bias)
-
-        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
+        lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
 
         if self.use_dora:
             x_eye: torch.Tensor = torch.eye(
@@ -98,8 +98,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
                 0
             )  # shape [1, out_features]
 
-            result_dora = (
-                scale_factor - 1
-            ) * base_out + scale_factor * lora_out * self.scaling
+            result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
             return result_dora
-        return base_out + lora_out * self.scaling
+
+        # scale the lora norm to prevent gradient explosion
+        orig_norm = torch.linalg.norm(w_base)
+        update_norm = torch.linalg.norm(lora_out)
+        scale = orig_norm / (update_norm + eps)
+
+        return base_out + lora_out * scale
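
For reference, a minimal standalone sketch (not part of the patch) of the norm-matched rescaling that the non-DoRA branch in linear.py applies at forward time. The tensor shapes, seed, and scaling value below are illustrative assumptions, not values taken from the integration.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features, rank, eps = 64, 128, 8, 1e-6  # hypothetical sizes
scaling = 2.0  # stand-in for alpha / rank

x = torch.randn(4, in_features)
w_base = torch.randn(out_features, in_features)
lora_A = torch.randn(rank, in_features) * 0.02   # [r, in_features]
lora_B = torch.randn(out_features, rank) * 0.02  # [out_features, r]

base_out = F.linear(x, w_base)
lora_out = F.linear(F.linear(x, lora_A), lora_B) * scaling

# Rescale the low-rank update so its norm tracks the base weight norm,
# mirroring scale = orig_norm / (update_norm + eps) in the diff above.
scale = torch.linalg.norm(w_base) / (torch.linalg.norm(lora_out) + eps)
out = base_out + lora_out * scale
print(out.shape)  # torch.Size([4, 128])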