support optional dora

Wing Lian
2025-01-20 11:45:06 -05:00
parent 623eaca740
commit b439ed3345
3 changed files with 52 additions and 15 deletions


@@ -276,4 +276,5 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
 if __name__ == "__main__":
-    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32)
+    # meta-llama/Llama-3.2-1B has 16 hidden layers
+    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
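
With the rsLoRA scaling defined in RelaxedRecursiveDoraLinear below (scaling = alpha / math.sqrt(rank)), the updated example call works out to an effective scale of 512 / sqrt(256) = 512 / 16 = 32.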


@@ -48,6 +48,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
         # rslora
         self.scaling = alpha / math.sqrt(rank)
+        self.use_dora = use_dora
         if use_dora:
             self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
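
Per recursion loop, the adapter therefore carries rank * (in_features + out_features) LoRA parameters (assuming lora_A is shaped [rank, in_features], as in standard LoRA), plus out_features magnitude entries when use_dora is enabled.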
@@ -70,7 +71,6 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_A: torch.Tensor = self.lora_A_list[loop_idx]
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
-        magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
         base_out: torch.Tensor = F.linear(x, w_base, self.bias)
@@ -81,9 +81,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
-        w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
-        w_dora_norm = w_dora_norm.detach()
-        scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
+        if self.use_dora:
+            magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
+            w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
+            w_dora_norm = w_dora_norm.detach()
+            scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features]
+            result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
+            return result_dora
-        result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
-        return result_dora
+        return base_out + lora_out * self.scaling
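
Read together, these hunks make the DoRA path optional: the per-loop magnitude vector is only looked up and applied when use_dora is set, and otherwise the layer falls back to plain rsLoRA, i.e. base_out + lora_out * scaling. Below is a minimal standalone sketch of such a layer; the class name, initialization, and the body of get_weight_norm are illustrative assumptions, while the combination formulas mirror the diff above.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchRecursiveDoraLinear(nn.Module):
    """Illustrative sketch of a recursive linear layer with optional DoRA (not the repo class)."""

    def __init__(self, in_features, out_features, recurse_loops, rank, alpha, bias=False, use_dora=True):
        super().__init__()
        # one base weight shared by every recursion loop (placeholder init for the sketch)
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # per-loop low-rank adapters
        self.lora_A_list = nn.ParameterList(
            [nn.Parameter(torch.randn(rank, in_features) * 0.01) for _ in range(recurse_loops)]
        )
        self.lora_B_list = nn.ParameterList(
            [nn.Parameter(torch.zeros(out_features, rank)) for _ in range(recurse_loops)]
        )
        # rslora scaling, as in the diff
        self.scaling = alpha / math.sqrt(rank)
        self.use_dora = use_dora
        if use_dora:
            # per-loop DoRA magnitude vectors, one entry per output feature
            self.lora_magnitude_vector_list = nn.ParameterList(
                [nn.Parameter(torch.ones(out_features)) for _ in range(recurse_loops)]
            )

    @staticmethod
    def get_weight_norm(w_base, w_delta, scaling):
        # column-wise L2 norm of the combined weight (assumed DoRA-style definition)
        return torch.linalg.norm(w_base + scaling * w_delta, dim=1)

    def forward(self, x, loop_idx: int):
        w_base = self.weight
        lora_A = self.lora_A_list[loop_idx]
        lora_B = self.lora_B_list[loop_idx]
        base_out = F.linear(x, w_base, self.bias)
        w_dora_full = lora_B @ lora_A  # [out_features, in_features]
        lora_out = F.linear(x, w_dora_full, bias=None)
        if self.use_dora:
            # DoRA-style combination exactly as in the diff above
            magnitude_vector = self.lora_magnitude_vector_list[loop_idx]
            w_dora_norm = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling).detach()
            scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0)  # [1, out_features]
            return (scale_factor - 1) * base_out + scale_factor * lora_out
        return base_out + lora_out * self.scaling

In a sketch like this, typically only the per-loop LoRA matrices and magnitude vectors are trained while the shared base weight stays frozen, which is the kind of split the get_nb_trainable_parameters helper added in the last file reports on.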


@@ -9,7 +9,7 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
-    LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel, LlamaRotaryEmbedding
+    LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
 from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
@@ -23,6 +23,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     recurse_layers: int = 4
     rank: int
     alpha: int
+    use_dora: bool = True

 class RelaxedRecursiveLlamaMLP(nn.Module):
@@ -32,9 +33,9 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
-        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
-        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
+        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
+        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
+        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
         self.act_fn = ACT2FN[config.hidden_act]

     def forward(self, x, loop_idx: int):
@@ -59,16 +60,16 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
         self.is_causal = True
         self.q_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
         self.k_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
         self.v_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
         self.o_proj = RelaxedRecursiveDoraLinear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )

     def forward(
@@ -327,3 +328,35 @@ class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
         # Initialize weights and apply final processing
         self.post_init()

+    def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
+        r"""
+        Returns the number of trainable parameters, the total number of parameters, and the number of LoRA parameters in the model.
+        """
+        trainable_params = 0
+        all_param = 0
+        lora_params = 0
+        for name, param in self.named_parameters():
+            num_params = param.numel()
+            # if using DS Zero 3 and the weights are initialized empty
+            if num_params == 0 and hasattr(param, "ds_numel"):
+                num_params = param.ds_numel
+            # Due to the design of 4bit linear layers from bitsandbytes
+            # one needs to multiply the number of parameters by 2 to get
+            # the correct number of parameters
+            if param.__class__.__name__ == "Params4bit":
+                if hasattr(param, "element_size"):
+                    num_bytes = param.element_size()
+                elif not hasattr(param, "quant_storage"):
+                    num_bytes = 1
+                else:
+                    num_bytes = param.quant_storage.itemsize
+                num_params = num_params * 2 * num_bytes
+            all_param += num_params
+            if param.requires_grad:
+                trainable_params += num_params
+            if "lora_" in name:
+                lora_params += num_params
+        return trainable_params, all_param, lora_params
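
As a quick usage note, the new accessor makes it easy to report the trainable fraction after the model is built. The construction call below is illustrative only (it assumes the config class accepts these keyword arguments; the converter script in this commit builds the model from a Llama checkpoint instead).

config = RelaxedRecursiveLlamaConfig(recurse_layers=4, rank=256, alpha=512, use_dora=True)
model = RelaxedRecursiveLlamaForCausalLM(config)

trainable, total, lora = model.get_nb_trainable_parameters()
print(f"trainable: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%), lora: {lora:,}")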