Support optional DoRA
This commit is contained in:
@@ -276,4 +276,5 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32)
|
# meta-llama/Llama-3.2-1B has 16 hidden layers
|
||||||
|
convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
|
self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
|
||||||
# rslora
|
# rslora
|
||||||
self.scaling = alpha / math.sqrt(rank)
|
self.scaling = alpha / math.sqrt(rank)
|
||||||
|
self.use_dora = use_dora
|
||||||
if use_dora:
|
if use_dora:
|
||||||
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
|
self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
|
||||||
|
|
||||||
@@ -70,7 +71,6 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
|
|
||||||
lora_A: torch.Tensor = self.lora_A_list[loop_idx]
|
lora_A: torch.Tensor = self.lora_A_list[loop_idx]
|
||||||
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
|
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
|
||||||
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
|
|
||||||
|
|
||||||
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
|
||||||
|
|
||||||
@@ -81,9 +81,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
|
|||||||
|
|
||||||
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
|
lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
|
||||||
|
|
||||||
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
|
if self.use_dora:
|
||||||
w_dora_norm = w_dora_norm.detach()
|
magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
|
||||||
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
|
w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
|
||||||
|
w_dora_norm = w_dora_norm.detach()
|
||||||
|
scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
|
||||||
|
|
||||||
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
|
||||||
return result_dora
|
return result_dora
|
||||||
|
return base_out + lora_out * self.scaling
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
|
||||||
LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel, LlamaRotaryEmbedding
|
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
|
||||||
|
|
||||||
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
|
||||||
|
|
||||||
@@ -23,6 +23,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
|
|||||||
recurse_layers: int = 4
|
recurse_layers: int = 4
|
||||||
rank: int
|
rank: int
|
||||||
alpha: int
|
alpha: int
|
||||||
|
use_dora: bool = True
|
||||||
|
|
||||||
|
|
||||||
class RelaxedRecursiveLlamaMLP(nn.Module):
|
class RelaxedRecursiveLlamaMLP(nn.Module):
|
||||||
@@ -32,9 +33,9 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.intermediate_size = config.intermediate_size
|
self.intermediate_size = config.intermediate_size
|
||||||
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
|
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
|
||||||
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
|
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
|
||||||
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
|
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
|
||||||
self.act_fn = ACT2FN[config.hidden_act]
|
self.act_fn = ACT2FN[config.hidden_act]
|
||||||
|
|
||||||
def forward(self, x, loop_idx: int):
|
def forward(self, x, loop_idx: int):
|
||||||
@@ -59,16 +60,16 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
|
|||||||
self.is_causal = True
|
self.is_causal = True
|
||||||
|
|
||||||
self.q_proj = RelaxedRecursiveDoraLinear(
|
self.q_proj = RelaxedRecursiveDoraLinear(
|
||||||
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
|
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
|
||||||
)
|
)
|
||||||
self.k_proj = RelaxedRecursiveDoraLinear(
|
self.k_proj = RelaxedRecursiveDoraLinear(
|
||||||
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
|
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
|
||||||
)
|
)
|
||||||
self.v_proj = RelaxedRecursiveDoraLinear(
|
self.v_proj = RelaxedRecursiveDoraLinear(
|
||||||
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
|
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
|
||||||
)
|
)
|
||||||
self.o_proj = RelaxedRecursiveDoraLinear(
|
self.o_proj = RelaxedRecursiveDoraLinear(
|
||||||
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
|
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -327,3 +328,35 @@ class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
|
|||||||
# Initialize weights and apply final processing
|
# Initialize weights and apply final processing
|
||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
|
def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
    r"""
    Count the parameters yielded by ``self.named_parameters()``.

    Returns:
        A ``(trainable_params, all_param, lora_params)`` tuple:

        * ``trainable_params`` - total size of parameters with ``requires_grad`` set.
        * ``all_param`` - total size of every parameter.
        * ``lora_params`` - total size of parameters whose name contains ``"lora_"``.
    """
    trainable_params = 0
    all_param = 0
    lora_params = 0
    for name, param in self.named_parameters():
        num_params = param.numel()
        # if using DS Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, "ds_numel"):
            num_params = param.ds_numel

        # Due to the design of 4bit linear layers from bitsandbytes
        # one needs to multiply the number of parameters by 2 (times the
        # per-element storage size in bytes) to get the correct number of
        # parameters
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "element_size"):
                num_bytes = param.element_size()
            elif not hasattr(param, "quant_storage"):
                num_bytes = 1
            else:
                num_bytes = param.quant_storage.itemsize
            num_params = num_params * 2 * num_bytes

        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
        # NOTE(review): the extracted diff lost its indentation, so it is not
        # visible whether this check was nested under ``requires_grad``.
        # It is treated here as an independent tally - confirm against the
        # repository.
        if "lora_" in name:
            lora_params += num_params

    return trainable_params, all_param, lora_params
|
||||||
|
|||||||
Reference in New Issue
Block a user