From e3393042e57f84a8e95a26607fda7fd023d12e75 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 21 Jan 2025 12:30:00 -0500 Subject: [PATCH] hopefully fix the lora/dora logic --- src/axolotl/integrations/rrt/cli/convert.py | 10 +++++- .../integrations/rrt/modeling/linear.py | 35 ++++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py index 50c06b988..7aceb5175 100644 --- a/src/axolotl/integrations/rrt/cli/convert.py +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -264,7 +264,13 @@ def save_state_dict_to_safetensors(state_dict, save_directory): def convert_llama_to_rrt( - model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None + model_name, + output_dir, + recurse_layers: int = 12, + rank=32, + alpha=32, + device=None, + use_dora=True, ): if not device: if torch.backends.mps.is_available(): @@ -299,6 +305,7 @@ def convert_llama_to_rrt( "recurse_layers": recurse_layers, "rank": rank, "alpha": alpha, + "use_dora": use_dora, } ) config.save_pretrained(output_dir) @@ -343,4 +350,5 @@ if __name__ == "__main__": recurse_layers=4, rank=256, alpha=512, + use_dora=False, ) diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py index 8cc86094c..009fafc4b 100644 --- a/src/axolotl/integrations/rrt/modeling/linear.py +++ b/src/axolotl/integrations/rrt/modeling/linear.py @@ -6,7 +6,6 @@ from peft.utils import transpose from torch import nn - class RelaxedRecursiveDoraLinear(nn.Module): """ A single linear layer that is "shared" across multiple loop iterations, @@ -44,13 +43,19 @@ class RelaxedRecursiveDoraLinear(nn.Module): else: self.register_parameter("bias", None) - self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]) - self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]) + self.lora_A_list = nn.ParameterList( + [nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)] + ) + self.lora_B_list = nn.ParameterList( + [nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)] + ) # rslora self.scaling = alpha / math.sqrt(rank) self.use_dora = use_dora if use_dora: - self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)]) + self.lora_magnitude_vector_list = nn.ParameterList( + [nn.Parameter(torch.ones(out_features)) for _ in range(B)] + ) def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: # calculate L2 norm of weight matrix, column-wise @@ -74,18 +79,24 @@ class RelaxedRecursiveDoraLinear(nn.Module): base_out: torch.Tensor = F.linear(x, w_base, self.bias) - x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype) - tmp = F.linear(x_eye, lora_A) # [hidden_size, rank] - w_dora_full: torch.Tensor = F.linear(tmp, lora_B) - w_dora_full = w_dora_full.t() - - lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None) + lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) if self.use_dora: + x_eye: torch.Tensor = torch.eye( + lora_A.shape[1], device=lora_A.device, dtype=x.dtype + ) + tmp = F.linear(x_eye, lora_A) # [hidden_size, rank] + w_dora_full: torch.Tensor = F.linear(tmp, lora_B) + w_dora_full = w_dora_full.t() + magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx] - w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling) + w_dora_norm: torch.Tensor = self.get_weight_norm( + w_base, w_dora_full.detach(), self.scaling + ) w_dora_norm = w_dora_norm.detach() - scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features] + scale_factor = (magnitude_vector / w_dora_norm).unsqueeze( + 0 + ) # shape [1, out_features] result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out return result_dora