diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py new file mode 100644 index 000000000..e3674b4be --- /dev/null +++ b/src/axolotl/integrations/rrt/cli/convert.py @@ -0,0 +1,99 @@ +from pathlib import Path +from typing import re + +import safetensors +import torch +from huggingface_hub import snapshot_download +from tqdm import tqdm +from transformers import AutoConfig + + +def extract_layer_number(key): + """Extract layer number from parameter key.""" + match = re.search(r'layers\.(\d+)\.', key) + return int(match.group(1)) if match else None + + +def iter_parameter_weights(model_path, device="cpu"): + """ + iterator over parameter weights in the model shards + + :param model_path: Path to model shards + :param device: Computing device + :return: generator yielding (parameter key, parameter weight, layer index) tuples + """ + shards = list(model_path.glob('model*.safetensors')) + if not shards: + raise ValueError(f"No model shards found in {model_path}") + + for shard in tqdm(shards, desc="Processing shards"): + with safetensors.safe_open(shard, framework='pt', device=device) as f: + for key in f.keys(): + layer_idx = extract_layer_number(key) + weight = f.get_tensor(key) + yield key, weight, layer_idx + +def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="cpu", recurse_layers=12): + # setup placeholder state_dict for recursive weights, need to keep in float32 precision + # to avoid precision loss when averaging weights across layers + rrt_avg_model_state_dict = {} + + # iterate over all parameter weights in the model shards + for key, weight, layer_idx in iter_parameter_weights(model_path): + # get the matching module name in modules_to_recurse for the current parameter key + matched_module_name = next( + (module for module in modules_to_recurse if module in key), + None + ) + if matched_module_name is None: + if "input_layernorm" in key: + # map to input_layernorm_list in the recursive layers and account for the layer_idx and loop_idx + yield + else: + yield key, weight + + recurse_idx = layer_idx % recurse_layers + suffix = f"{recurse_idx}.{matched_module_name}" + prefix = f"model.layers.{suffix}." + if rrt_avg_model_state_dict.get(suffix) is None: + # setup as storage for suffix with torch.stack + rrt_avg_model_state_dict[suffix] = torch.stack([weight.to(torch.float32).detach().cpu()]) + else: + rrt_avg_model_state_dict[suffix] = torch.cat([rrt_avg_model_state_dict[suffix], weight.to(torch.float32).detach().cpu()]) + + for module_name in modules_to_recurse: + for recurse_idx in range(recurse_layers): + suffix = f"{recurse_idx}.{module_name}" + prefix = f"model.layers.{suffix}." + avg_weight = rrt_avg_model_state_dict[suffix].mean(dim=0) + yield f"{prefix}.weight", avg_weight + + +def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12): + modules_to_recurse = [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.down_proj", + "mlp.gate_proj", + "mlp.up_proj", + ] + + config = AutoConfig.from_pretrained(model_name) + num_hidden_layers = config.num_hidden_layers + if num_hidden_layers % recurse_layers != 0: + raise ValueError( + f"The number of hidden layers ({num_hidden_layers}) in the model must be " + f"divisible by the recurse layers ({recurse_layers})" + ) + + model_path = Path(snapshot_download(model_name)) + + # create a new state_dict to store the RRT model weights + rrt_model_state_dict = {} + + for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="cpu", recurse_layers=recurse_layers): + rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu() + + # split_torch_state_dict_into_shards(...) diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py new file mode 100644 index 000000000..270e8875a --- /dev/null +++ b/src/axolotl/integrations/rrt/modeling/linear.py @@ -0,0 +1,73 @@ +import torch +import torch.nn.functional as F +from torch import nn, transpose + + +class RelaxedRecursiveDoraLinear(nn.Module): + """ + A single linear layer that is "shared" across multiple loop iterations, + but each iteration has its own DoRA offsets (A_i, B_i, magnitude_i). + + The constructor expects you to specify: + - in_features, out_features + - B: number of loop iterations (i.e., how many times we "unroll") + - fan_in_fan_out: pass True if your underlying base weight is transposed, etc. + + The forward(...) expects an additional argument "loop_idx" in [0..B-1], + which picks out the iteration-specific DoRA offsets. + """ + + def __init__( + self, + in_features: int, + out_features: int, + B: int, + rank: int, + fan_in_fan_out: bool = False, + bias: bool = True, + use_dora: bool = True, + ): + super().__init__() + self.B = B + self.fan_in_fan_out = fan_in_fan_out + + self.weight_base = nn.Parameter(torch.empty(out_features, in_features)) + + self.use_bias = bias + if self.use_bias: + self.bias = nn.Parameter(torch.zeros(out_features)) + else: + self.register_parameter("bias", None) + + self.lora_A_list = nn.ParameterList([nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]) + self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]) + if use_dora: + self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)]) + + def forward(self, x, loop_idx: int): + """ + + :param x: hidden state of shape (batch_size, seq_len, in_features) + :param loop_idx: + :return: + """ + w_base = self.weight_base + w_base = w_base.to(x.dtype) + + lora_A: torch.Tensor = self.lora_A_list[loop_idx] + lora_B: torch.Tensor = self.lora_B_list[loop_idx] + magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx] + + base_out: torch.Tensor = F.linear(x, transpose(w_base, self.fan_in_fan_out), self.bias) + + x_eye: torch.Tensor = torch.eye(lora_A.shape[1], device=lora_A.device, dtype=x.dtype) + w_dora_full: torch.Tensor = lora_B(lora_A(x_eye)) + + lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None) + + w_dora_norm: torch.Tensor = self.get_weight_norm(w_base, w_dora_full.detach()) + w_dora_norm = w_dora_norm.detach() + scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0) # shape [1, out_features] + + result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out + return result_dora diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py index e69de29bb..02497ef86 100644 --- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py +++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py @@ -0,0 +1,322 @@ +from typing import Tuple, Optional, Unpack, Callable, Union + +import torch +from torch import nn +from transformers import LlamaConfig, Cache, logger, DynamicCache +from transformers.activations import ACT2FN +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import BaseModelOutputWithPast +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \ + LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel, LlamaRotaryEmbedding + +from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear + + +class RelaxedRecursiveLlamaConfig(LlamaConfig): + """ + Configuration for Relaxed Recursive Llama. + """ + + recurse_layers: int + rank: int + + +class RelaxedRecursiveLlamaMLP(nn.Module): + def __init__(self, config: RelaxedRecursiveLlamaConfig): + super().__init__() + recurse_loops = config.num_layers // config.recurse_layers + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias) + self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias) + self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x, loop_idx: int): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx) + return down_proj + + +class RelaxedRecursiveLlamaAttention(nn.Module): + """ + A single attention layer of the Relaxed Recursive Llama. + """ + + def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int): + super().__init__() + recurse_loops = config.num_layers // config.recurse_layers + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = config.attention_dropout + self.is_causal = True + + self.q_proj = RelaxedRecursiveDoraLinear( + config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + ) + self.k_proj = RelaxedRecursiveDoraLinear( + config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + ) + self.v_proj = RelaxedRecursiveDoraLinear( + config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias + ) + self.o_proj = RelaxedRecursiveDoraLinear( + config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, bias=config.attention_bias + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + loop_idx: int, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): + logger.warning_once( + "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " + 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output, loop_idx) + return attn_output, attn_weights + + + +class RelaxedRecursiveLlamaDecoderLayer(nn.Module): + """ + A single layer of the Relaxed Recursive Llama decoder. + """ + + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + recurse_loops = config.num_layers // config.recurse_layers + self.hidden_size = config.hidden_size + + self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx) + + self.mlp = RelaxedRecursiveLlamaMLP(config) + + self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)]) + self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)]) + + def forward( + self, + hidden_states: torch.Tensor, + loop_idx: int, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + **kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + residual = hidden_states + + hidden_states = self.input_layernorm_list[loop_idx](hidden_states) + + # Self Attention + hidden_states, self_attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + loop_idx=loop_idx, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm_list[loop_idx](hidden_states) + hidden_states = self.mlp(hidden_states, loop_idx) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + + return outputs + + +class RelaxedRecursiveLlamaModel(LlamaModel): + def __init__(self, config): + super(LlamaModel, self).__init__(config) + self.recurse_loops = config.num_layers // config.recurse_layers + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # create position embeddings to be shared across the decoder layers + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for loop_idx in range(self.recurse_loops): + for decoder_layer in self.layers[: self.config.recurse_layers]: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + loop_idx, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + loop_idx, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + output = BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + return output if return_dict else output.to_tuple() + + +class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM): + def __init__(self, config): + super(LlamaForCausalLM, self).__init__(config) + self.model = RelaxedRecursiveLlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() +