hopefully fix the lora/dora logic
This commit is contained in:
@@ -264,7 +264,13 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
|
|||||||
|
|
||||||
|
|
||||||
def convert_llama_to_rrt(
|
def convert_llama_to_rrt(
|
||||||
model_name, output_dir, recurse_layers: int = 12, rank=32, alpha=32, device=None
|
model_name,
|
||||||
|
output_dir,
|
||||||
|
recurse_layers: int = 12,
|
||||||
|
rank=32,
|
||||||
|
alpha=32,
|
||||||
|
device=None,
|
||||||
|
use_dora=True,
|
||||||
):
|
):
|
||||||
if not device:
|
if not device:
|
||||||
if torch.backends.mps.is_available():
|
if torch.backends.mps.is_available():
|
||||||
@@ -299,6 +305,7 @@ def convert_llama_to_rrt(
|
|||||||
"recurse_layers": recurse_layers,
|
"recurse_layers": recurse_layers,
|
||||||
"rank": rank,
|
"rank": rank,
|
||||||
"alpha": alpha,
|
"alpha": alpha,
|
||||||
|
"use_dora": use_dora,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
config.save_pretrained(output_dir)
|
config.save_pretrained(output_dir)
|
||||||
@@ -343,4 +350,5 @@ if __name__ == "__main__":
|
|||||||
recurse_layers=4,
|
recurse_layers=4,
|
||||||
rank=256,
|
rank=256,
|
||||||
alpha=512,
|
alpha=512,
|
||||||
|
use_dora=False,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from peft.utils import transpose
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RelaxedRecursiveDoraLinear(nn.Module):
|
class RelaxedRecursiveDoraLinear(nn.Module):
|
||||||
"""
|
"""
|
||||||
A single linear layer that is "shared" across multiple loop iterations,
|
A single linear layer that is "shared" across multiple loop iterations,
|
||||||
@@ -44,13 +43,19 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.register_parameter("bias", None)
|
self.register_parameter("bias", None)
|
||||||
|
|
||||||
self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)])
|
self.lora_A_list = nn.ParameterList(
|
||||||
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
|
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
|
||||||
|
)
|
||||||
|
self.lora_B_list = nn.ParameterList(
|
||||||
|
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
|
||||||
|
)
|
||||||
# rslora
|
# rslora
|
||||||
self.scaling = alpha / math.sqrt(rank)
|
self.scaling = alpha / math.sqrt(rank)
|
||||||
self.use_dora = use_dora
|
self.use_dora = use_dora
|
||||||
if use_dora:
|
if use_dora:
|
||||||
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
|
self.lora_magnitude_vector_list = nn.ParameterList(
|
||||||
|
[nn.Parameter(torch.ones(out_features)) for _ in range(B)]
|
||||||
|
)
|
||||||
|
|
||||||
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
|
def get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor:
|
||||||
# calculate L2 norm of weight matrix, column-wise
|
# calculate L2 norm of weight matrix, column-wise
|
||||||
@@ -74,18 +79,24 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
|
|
||||||
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
||||||
|
|
||||||
x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype)
|
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B)
|
||||||
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
|
|
||||||
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
|
|
||||||
w_dora_full = w_dora_full.t()
|
|
||||||
|
|
||||||
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
|
|
||||||
|
|
||||||
if self.use_dora:
|
if self.use_dora:
|
||||||
|
x_eye: torch.Tensor = torch.eye(
|
||||||
|
lora_A.shape[1], device=lora_A.device, dtype=x.dtype
|
||||||
|
)
|
||||||
|
tmp = F.linear(x_eye, lora_A) # [hidden_size, rank]
|
||||||
|
w_dora_full: torch.Tensor = F.linear(tmp, lora_B)
|
||||||
|
w_dora_full = w_dora_full.t()
|
||||||
|
|
||||||
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
|
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
|
||||||
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
|
w_dora_norm: torch.Tensor = self.get_weight_norm(
|
||||||
|
w_base, w_dora_full.detach(), self.scaling
|
||||||
|
)
|
||||||
w_dora_norm = w_dora_norm.detach()
|
w_dora_norm = w_dora_norm.detach()
|
||||||
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
|
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(
|
||||||
|
0
|
||||||
|
) # shape [1, out_features]
|
||||||
|
|
||||||
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
||||||
return result_dora
|
return result_dora
|
||||||
|
|||||||
Reference in New Issue
Block a user