commit daa9408233 (parent 257231ac46)
Author: Wing Lian
Date:   2025-01-19 20:11:26 -05:00

3 changed files with 494 additions and 0 deletions


@@ -0,0 +1,99 @@
import re
from pathlib import Path

import safetensors
import torch
from huggingface_hub import snapshot_download
from tqdm import tqdm
from transformers import AutoConfig

def extract_layer_number(key):
"""Extract layer number from parameter key."""
match = re.search(r'layers\.(\d+)\.', key)
return int(match.group(1)) if match else None

def iter_parameter_weights(model_path, device="cpu"):
    """
    Iterate over the parameter weights stored in the model's safetensors shards.

    :param model_path: Path to the directory containing the model shards
    :param device: Device used to load the tensors
    :return: generator yielding (parameter key, parameter weight, layer index) tuples
    """
shards = list(model_path.glob('model*.safetensors'))
if not shards:
raise ValueError(f"No model shards found in {model_path}")
for shard in tqdm(shards, desc="Processing shards"):
with safetensors.safe_open(shard, framework='pt', device=device) as f:
for key in f.keys():
layer_idx = extract_layer_number(key)
weight = f.get_tensor(key)
yield key, weight, layer_idx

def iter_recursive_parameter_weights(model_path, modules_to_recurse: list[str], device="cpu", recurse_layers=12):
    # Set up a placeholder state dict for the recursive weights; keep it in float32
    # to avoid precision loss when averaging weights across layers.
    rrt_avg_model_state_dict = {}
    # iterate over all parameter weights in the model shards
    for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
# get the matching module name in modules_to_recurse for the current parameter key
matched_module_name = next(
(module for module in modules_to_recurse if module in key),
None
)
        if matched_module_name is None:
            if layer_idx is not None and "layernorm" in key:
                # map input_layernorm / post_attention_layernorm onto the per-loop
                # ModuleLists of the shared layer, accounting for layer_idx and loop_idx
                recurse_idx, loop_idx = layer_idx % recurse_layers, layer_idx // recurse_layers
                key = re.sub(r"layers\.\d+\.(\w+_layernorm)\.", f"layers.{recurse_idx}.\\g<1>_list.{loop_idx}.", key)
            # everything else (embeddings, final norm, lm_head) passes through unchanged
            yield key, weight
            continue
        recurse_idx = layer_idx % recurse_layers
        suffix = f"{recurse_idx}.{matched_module_name}"
        weight_fp32 = weight.to(torch.float32).detach().cpu()
        if rrt_avg_model_state_dict.get(suffix) is None:
            # first occurrence: store with a leading stack dimension
            rrt_avg_model_state_dict[suffix] = weight_fp32.unsqueeze(0)
        else:
            # later occurrences: append along the stack dimension (cat requires matching ndim)
            rrt_avg_model_state_dict[suffix] = torch.cat(
                [rrt_avg_model_state_dict[suffix], weight_fp32.unsqueeze(0)], dim=0
            )
    # after all shards are consumed, emit the per-slot averages for the recursed modules
    for module_name in modules_to_recurse:
        for recurse_idx in range(recurse_layers):
            suffix = f"{recurse_idx}.{module_name}"
            avg_weight = rrt_avg_model_state_dict[suffix].mean(dim=0)
            yield f"model.layers.{suffix}.weight", avg_weight
def convert_llama_to_rrt(model_name, output_dir, recurse_layers: int = 12):
modules_to_recurse = [
"self_attn.q_proj",
"self_attn.k_proj",
"self_attn.v_proj",
"self_attn.o_proj",
"mlp.down_proj",
"mlp.gate_proj",
"mlp.up_proj",
]
config = AutoConfig.from_pretrained(model_name)
num_hidden_layers = config.num_hidden_layers
if num_hidden_layers % recurse_layers != 0:
raise ValueError(
f"The number of hidden layers ({num_hidden_layers}) in the model must be "
f"divisible by the recurse layers ({recurse_layers})"
)
model_path = Path(snapshot_download(model_name))
# create a new state_dict to store the RRT model weights
rrt_model_state_dict = {}
for key, weight in iter_recursive_parameter_weights(model_path, modules_to_recurse, device="cpu", recurse_layers=recurse_layers):
rrt_model_state_dict[key] = weight.to(torch.bfloat16).detach().cpu()
# split_torch_state_dict_into_shards(...)
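    # Hedged sketch of the save step hinted at above, assuming huggingface_hub's
    # split_torch_state_dict_into_shards and safetensors.torch.save_file; the exact
    # serialization intended here may differ.
    from huggingface_hub import split_torch_state_dict_into_shards
    from safetensors.torch import save_file

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    state_dict_split = split_torch_state_dict_into_shards(rrt_model_state_dict)
    for shard_file, tensor_names in state_dict_split.filename_to_tensors.items():
        shard = {name: rrt_model_state_dict[name] for name in tensor_names}
        save_file(shard, str(output_path / shard_file), metadata={"format": "pt"})
    # when state_dict_split.is_sharded, an index file built from
    # state_dict_split.filename_to_tensors / .metadata would also be written here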


@@ -0,0 +1,73 @@
import math

import torch
import torch.nn.functional as F
from torch import nn

class RelaxedRecursiveDoraLinear(nn.Module):
"""
A single linear layer that is "shared" across multiple loop iterations,
but each iteration has its own DoRA offsets (A_i, B_i, magnitude_i).
The constructor expects you to specify:
- in_features, out_features
- B: number of loop iterations (i.e., how many times we "unroll")
- fan_in_fan_out: pass True if your underlying base weight is transposed, etc.
The forward(...) expects an additional argument "loop_idx" in [0..B-1],
which picks out the iteration-specific DoRA offsets.
"""
def __init__(
self,
in_features: int,
out_features: int,
B: int,
rank: int,
fan_in_fan_out: bool = False,
bias: bool = True,
use_dora: bool = True,
):
super().__init__()
        self.B = B
        self.fan_in_fan_out = fan_in_fan_out
        self.use_dora = use_dora
        # shared base weight; expected to be populated from a converted checkpoint
        self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
        self.use_bias = bias
        if self.use_bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)
        # standard LoRA-style init: A gets a small random init, B starts at zero so the
        # per-iteration delta weight is zero at the start of training
        self.lora_A_list = nn.ParameterList([nn.Parameter(torch.empty(rank, in_features)) for _ in range(B)])
        self.lora_B_list = nn.ParameterList([nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)])
        for lora_A in self.lora_A_list:
            nn.init.kaiming_uniform_(lora_A, a=math.sqrt(5))
        if self.use_dora:
            self.lora_magnitude_vector_list = nn.ParameterList([nn.Parameter(torch.ones(out_features)) for _ in range(B)])
def forward(self, x, loop_idx: int):
"""
:param x: hidden state of shape (batch_size, seq_len, in_features)
:param loop_idx:
:return:
"""
        w_base = self.weight_base.to(x.dtype)
        if self.fan_in_fan_out:
            # base weight stored as (in_features, out_features); transpose for F.linear
            w_base = w_base.T
        lora_A: torch.Tensor = self.lora_A_list[loop_idx]
        lora_B: torch.Tensor = self.lora_B_list[loop_idx]
        base_out: torch.Tensor = F.linear(x, w_base, self.bias)
        # materialize the full per-iteration delta weight of shape (out_features, in_features)
        w_dora_full: torch.Tensor = (lora_B @ lora_A).to(x.dtype)
        lora_out: torch.Tensor = F.linear(x, w_dora_full, bias=None)
        if not self.use_dora:
            return base_out + lora_out
        magnitude_vector: torch.Tensor = self.lora_magnitude_vector_list[loop_idx]
        # column-wise norm of the combined weight, detached so the magnitude vector is
        # the only learnable scale (as in DoRA)
        w_dora_norm = torch.linalg.norm(w_base + w_dora_full, dim=1).detach()
        scale_factor = (magnitude_vector / w_dora_norm).unsqueeze(0)  # shape [1, out_features]
        result_dora = (scale_factor - 1) * base_out + scale_factor * lora_out
        # DoRA output m * (W0 + dW) x / ||W0 + dW|| == scale * (base_out + lora_out)
        return base_out + result_dora
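

# A small usage sketch of the layer above, with made-up sizes: one shared 64x64 base
# weight, B=3 loop iterations, and rank-8 offsets per iteration.
if __name__ == "__main__":
    layer = RelaxedRecursiveDoraLinear(in_features=64, out_features=64, B=3, rank=8)
    # the shared base weight normally comes from the conversion script; random init here
    nn.init.normal_(layer.weight_base, std=0.02)
    x = torch.randn(2, 5, 64)
    outputs = [layer(x, loop_idx=i) for i in range(layer.B)]
    assert all(out.shape == (2, 5, 64) for out in outputs)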


@@ -0,0 +1,322 @@
from typing import Tuple, Optional, Unpack, Callable, Union

import torch
from torch import nn
from transformers import LlamaConfig, Cache, DynamicCache
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
    LlamaForCausalLM, LlamaPreTrainedModel, LlamaModel, LlamaRotaryEmbedding
from transformers.utils import logging

from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear

# transformers does not export `logger` at the top level; create a module logger instead
logger = logging.get_logger(__name__)

class RelaxedRecursiveLlamaConfig(LlamaConfig):
    """
    Configuration for Relaxed Recursive Llama.

    Adds two fields on top of LlamaConfig:
    - recurse_layers: number of physical (shared) decoder layers
    - rank: rank of the per-iteration DoRA offsets
    """

    recurse_layers: int
    rank: int

class RelaxedRecursiveLlamaMLP(nn.Module):
    """
    MLP block whose projections are shared across loop iterations via RelaxedRecursiveDoraLinear.
    """

    def __init__(self, config: RelaxedRecursiveLlamaConfig):
        super().__init__()
        recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, bias=config.mlp_bias)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx)
return down_proj
class RelaxedRecursiveLlamaAttention(nn.Module):
"""
A single attention layer of the Relaxed Recursive Llama.
"""
    def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
        super().__init__()
        recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, bias=config.attention_bias
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, bias=config.attention_bias
)
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
loop_idx: int,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            # each loop iteration over the same physical layer needs its own cache slot, so
            # index the cache by the "virtual" layer (loop_idx * recurse_layers + layer_idx)
            cache_layer_idx = loop_idx * self.config.recurse_layers + self.layer_idx
            key_states, value_states = past_key_value.update(key_states, value_states, cache_layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
"""
A single layer of the Relaxed Recursive Llama decoder.
"""
    def __init__(self, config: RelaxedRecursiveLlamaConfig, layer_idx: int):
        super().__init__()
        recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
def forward(
self,
hidden_states: torch.Tensor,
loop_idx: int,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
# Self Attention
hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
loop_idx=loop_idx,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm_list[loop_idx](hidden_states)
hidden_states = self.mlp(hidden_states, loop_idx)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
return outputs
class RelaxedRecursiveLlamaModel(LlamaModel):
    """
    Llama model that unrolls a small stack of shared decoder layers `recurse_loops` times.
    """

    def __init__(self, config):
        super(LlamaModel, self).__init__(config)
        self.recurse_loops = config.num_hidden_layers // config.recurse_layers
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList(
[RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if use_cache and past_key_values is None:
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
hidden_states = inputs_embeds
# create position embeddings to be shared across the decoder layers
position_embeddings = self.rotary_emb(hidden_states, position_ids)
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
        # relaxed recursion: the same `recurse_layers` physical decoder layers are unrolled
        # `recurse_loops` times; loop_idx selects the per-iteration DoRA offsets and norms
        for loop_idx in range(self.recurse_loops):
            for decoder_layer in self.layers[: self.config.recurse_layers]:
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
loop_idx,
causal_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
else:
layer_outputs = decoder_layer(
hidden_states,
loop_idx,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
output = BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
return output if return_dict else output.to_tuple()
class RelaxedRecursiveLlamaForCausalLM(LlamaForCausalLM):
    """
    Causal LM head on top of RelaxedRecursiveLlamaModel.
    """

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)
        self.model = RelaxedRecursiveLlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
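

# A hedged end-to-end sketch with a deliberately tiny, made-up configuration: four shared
# decoder layers unrolled over num_hidden_layers=8 (two loops) with rank-8 offsets. Weights
# are random, so this only exercises shapes and the recursion plumbing.
if __name__ == "__main__":
    config = RelaxedRecursiveLlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=8,
        num_attention_heads=4,
        num_key_value_heads=4,
        vocab_size=128,
        recurse_layers=4,
        rank=8,
    )
    model = RelaxedRecursiveLlamaForCausalLM(config)
    # the shared base weights normally come from the conversion script; random init here
    for module in model.modules():
        if isinstance(module, RelaxedRecursiveDoraLinear):
            nn.init.normal_(module.weight_base, std=0.02)
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=input_ids, use_cache=False)
    print(out.logits.shape)  # torch.Size([1, 16, 128])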