hopefully fix the lora/dora logic

This commit is contained in:
Wing Lian
2025-01-21 12:30:00 -05:00
parent 08a4e8a7fb
commit e3393042e5
2 changed files with 32 additions and 13 deletions

View File

@@ -264,7 +264,13 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
def convert_llama_to_rrt(
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
model_name,
output_dir,
recurse_layers: int = 12,
rank=32,
alpha=32,
device=None,
use_dora=True,
):
if not device:
if torch.backends.mps.is_available():
@@ -299,6 +305,7 @@ def convert_llama_to_rrt(
"recurse_layers": recurse_layers,
"rank": rank,
"alpha": alpha,
"use_dora": use_dora,
}
)
config.save_pretrained(output_dir)
@@ -343,4 +350,5 @@ if __name__ == "__main__":
recurse_layers=4,
rank=256,
alpha=512,
use_dora=False,
)

View File

@@ -6,7 +6,6 @@ from peft.utils import transpose
from torch import nn
class RelaxedRecursiveDoraLinear(nn.Module):
"""
A single linear layer that is "shared" across multiple loop iterations,
@@ -44,13 +43,19 @@ class RelaxedRecursiveDoraLinear(nn.Module):
else:
self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)])
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
self.lora_A_list = nn.ParameterList(
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
)
self.lora_B_list = nn.ParameterList(
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
)
# rslora
self.scaling = alpha / math.sqrt(rank)
self.use_dora = use_dora
if use_dora:
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
self.lora_magnitude_vector_list = nn.ParameterList(
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
)
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
# calculate L2 norm of weight matrix, column-wise
@@ -74,18 +79,24 @@ class RelaxedRecursiveDoraLinear(nn.Module):
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
if self.use_dora:
x_eye: torch.Tensor = torch.eye(
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
)
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
w_dora_full = w_dora_full.t()
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
w_dora_norm: torch.Tensor = self.get_weight_norm(
w_base, w_dora_full.detach(), self.scaling
)
w_dora_norm = w_dora_norm.detach()
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
0
) # shape [1, out_features]
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
return result_dora