From b439ed33457cc8dbc483c4d961b751d464f464b1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 20 Jan 2025 11:45:06 -0500
Subject: [PATCH] support optional dora

---
 src/axolotl/integrations/rrt/cli/convert.py   |  3 +-
 .../integrations/rrt/modeling/linear.py       | 15 +++---
 .../rrt/modeling/modeling_rrt_llama.py        | 49 ++++++++++++++++---
 3 files changed, 52 insertions(+), 15 deletions(-)

diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index f755b38e8..38d994c93 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -276,4 +276,5 @@ def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12, rank=
 
 if __name__ == "__main__":
-    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=32)
+    # meta-llama/Llama-3.2-1B has 16 hidden layers
+    convert_llama_to_rrt("meta-llama/Llama-3.2-1B", "/tmp/rrt_model", recurse_layers=4, rank=256, alpha=512)
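
Note on the updated example call: a minimal sketch of the arithmetic implied by the new "16 hidden layers" comment, assuming the conversion folds `num_hidden_layers` into `recurse_layers`-sized shared blocks that are looped over (the rsLoRA scaling comes from `linear.py` below; the derived numbers are illustrative, not output of the script):

```python
import math

# Assumption: Llama-3.2-1B's 16 hidden layers are tied into blocks of
# `recurse_layers`, so the shared block is traversed num_layers / recurse_layers times.
num_hidden_layers = 16
recurse_layers = 4
recurse_loops = num_hidden_layers // recurse_layers  # 4 passes over the shared block

# rsLoRA scaling used by RelaxedRecursiveDoraLinear: alpha / sqrt(rank)
rank, alpha = 256, 512
print(recurse_loops, alpha / math.sqrt(rank))  # -> 4 32.0
```
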
diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py
index c5df20e57..8cc86094c 100644
--- a/src/axolotl/integrations/rrt/modeling/linear.py
+++ b/src/axolotl/integrations/rrt/modeling/linear.py
@@ -48,6 +48,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
         self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
         # rslora
         self.scaling = alpha / math.sqrt(rank)
+        self.use_dora = use_dora
         if use_dora:
             self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
 
@@ -70,7 +71,6 @@
 
         lora_A: torch.Tensor = self.lora_A_list[loop_idx]
         lora_B: torch.Tensor = self.lora_B_list[loop_idx]
-        magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
 
         base_out: torch.Tensor = F.linear(x, w_base, self.bias)
 
@@ -81,9 +81,12 @@
 
         lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
 
-        w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
-        w_dora_norm = w_dora_norm.detach()
-        scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0)  # shape [1, out_features]
+        if self.use_dora:
+            magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
+            w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach(), self.scaling)
+            w_dora_norm = w_dora_norm.detach()
+            scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0)  # shape [1, out_features]
 
-        result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
-        return result_dora
+            result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
+            return result_dora
+        return base_out + lora_out * self.scaling
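
With `use_dora` stored on the module, the forward pass now has two code paths. A standalone sketch of the math, assuming `w_dora_full` holds the low-rank update (`lora_B @ lora_A`) for the current loop and that `get_weight_norm` computes the column-wise norm of `w_base + scaling * delta` as in PEFT-style DoRA; this mirrors the patched branches and is not the module itself:

```python
import torch

def recursive_adapter_forward(x, w_base, w_delta, scaling, magnitude=None, use_dora=True):
    # x: [batch, in_features]; w_base, w_delta: [out_features, in_features]; magnitude: [out_features]
    base_out = x @ w_base.T
    lora_out = x @ w_delta.T
    if use_dora:
        # DoRA path: rescale each output feature by magnitude / ||w_base + scaling * w_delta||
        w_norm = (w_base + scaling * w_delta).norm(p=2, dim=1).detach()
        scale = (magnitude / w_norm).unsqueeze(0)  # [1, out_features]
        return (scale - 1) * base_out + scale * lora_out
    # plain rsLoRA path introduced by this patch
    return base_out + lora_out * scaling
```
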
diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
index 6b4595435..4b64c7bf4 100644
--- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
+++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
@@ -9,7 +9,7 @@ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
-    LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel, LlamaRotaryEmbedding
+    LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
 
 from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
 
@@ -23,6 +23,7 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     recurse_layers: int = 4
     rank: int
     alpha: int
+    use_dora: bool = True
 
 
 class RelaxedRecursiveLlamaMLP(nn.Module):
@@ -32,9 +33,9 @@
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
-        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
-        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias)
+        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
+        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
+        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x, loop_idx: int):
@@ -59,16 +60,16 @@
         self.is_causal = True
 
         self.q_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
         self.k_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
         self.v_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
        )
         self.o_proj = RelaxedRecursiveDoraLinear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias
+            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
         )
 
     def forward(
@@ -327,3 +328,35 @@ class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def get_nb_trainable_parameters(self) -> tuple[int, int, int]:
+        r"""
+        Returns the number of trainable parameters, the total number of parameters, and the number of LoRA parameters in the model.
+        """
+        trainable_params = 0
+        all_param = 0
+        lora_params = 0
+        for name, param in self.named_parameters():
+            num_params = param.numel()
+            # if using DS Zero 3 and the weights are initialized empty
+            if num_params == 0 and hasattr(param, "ds_numel"):
+                num_params = param.ds_numel
+
+            # Due to the design of 4bit linear layers from bitsandbytes
+            # one needs to multiply the number of parameters by 2 to get
+            # the correct number of parameters
+            if param.__class__.__name__ == "Params4bit":
+                if hasattr(param, "element_size"):
+                    num_bytes = param.element_size()
+                elif not hasattr(param, "quant_storage"):
+                    num_bytes = 1
+                else:
+                    num_bytes = param.quant_storage.itemsize
+                num_params = num_params * 2 * num_bytes
+
+            all_param += num_params
+            if param.requires_grad:
+                trainable_params += num_params
+            if "lora_" in name:
+                lora_params += num_params
+
+        return trainable_params, all_param, lora_params
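
For reference, a hypothetical end-to-end sketch of toggling the new flag and reading the new parameter counts; the class and config names come from this patch, but the loading path (plain `from_pretrained` on the converted checkpoint) is an assumption, not something this diff exercises:

```python
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
    RelaxedRecursiveLlamaConfig,
    RelaxedRecursiveLlamaForCausalLM,
)

# Assumes /tmp/rrt_model was produced by convert_llama_to_rrt above.
config = RelaxedRecursiveLlamaConfig.from_pretrained("/tmp/rrt_model")
config.use_dora = False  # plain rsLoRA adapters; skips the per-loop magnitude vectors

model = RelaxedRecursiveLlamaForCausalLM.from_pretrained("/tmp/rrt_model", config=config)
trainable, total, lora = model.get_nb_trainable_parameters()
print(f"trainable={trainable:,} total={total:,} lora={lora:,}")
```
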