chore: lint

Wing Lian
2025-01-24 13:29:54 -05:00
parent 0af78a9882
commit 791c38dcc3
7 changed files with 214 additions and 75 deletions

View File

@@ -4,15 +4,8 @@ Axolotl Plugin for Relaxed Recursive Transformers
import logging
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
LOG = logging.getLogger(__name__)
@@ -30,18 +23,3 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
"Registering Relaxed Recursive Transformers modeling with transformers"
)
register_rrt_model()
def register_rrt_model():
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -1,3 +1,7 @@
"""
Axolotl config args for Relaxed Recursive Transformers plugin
"""
from pydantic import BaseModel
@@ -5,4 +9,3 @@ class RelaxedRecursiveTransformerArgs(BaseModel):
"""
Arguments pertaining to the Relaxed Recursive Transformer model.
"""
...

View File

@@ -1,3 +1,6 @@
"""
CLI script for converting a pretrained model to a Relaxed Recursive Transformer model
"""
import json
import logging
import math
@@ -52,7 +55,7 @@ def iter_recursive_parameter_weights(
):
# set up a placeholder state_dict for the recursive weights; keep them in float32
# precision to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict = {}
rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
# iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
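The comment above describes accumulating each group's layer weights in float32 before averaging. A minimal sketch of that idea (a hypothetical helper, not code from this script):

import torch

def average_layer_weights(weights: list[torch.Tensor]) -> torch.Tensor:
    # upcast before accumulating so repeated additions do not lose precision
    acc = torch.zeros_like(weights[0], dtype=torch.float32)
    for w in weights:
        acc += w.to(torch.float32)
    return acc / len(weights)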
@@ -93,6 +96,7 @@ def low_rank_decomposition(
:param max_rank: The maximum rank of the decomposition
:return: A tuple of tensors (L, R)
"""
# pylint: disable=invalid-name
assert (
weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
@@ -135,7 +139,9 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True)
delta_for_svd = final_weight - base_weight
# Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
delta_for_svd, rank
)
if use_dora:
lora_weight = lora_B @ lora_A
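When use_dora is enabled, the low-rank delta is paired with a magnitude term. A sketch of how such a magnitude is typically derived (assumption: a row-wise L2 norm of the merged weight, as in common DoRA implementations; the exact form used here is not shown in this excerpt):

import torch

def dora_magnitude(base_weight: torch.Tensor, lora_weight: torch.Tensor) -> torch.Tensor:
    # one magnitude entry per output row of the merged (out_features, in_features) weight
    return (base_weight + lora_weight).norm(p=2, dim=1)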
@@ -218,6 +224,7 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
)
# pylint: disable=duplicate-code
# Save index if sharded
index = None
if state_dict_split.is_sharded:
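The sharded save path relies on huggingface_hub to split the state dict and safetensors to write each shard. A rough sketch of that pattern (a sketch assuming the current huggingface_hub/safetensors APIs; filenames are illustrative):

import json
import os

from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file

def save_sharded(state_dict, save_directory):
    split = split_torch_state_dict_into_shards(
        state_dict, filename_pattern="model{suffix}.safetensors", max_shard_size="1GB"
    )
    # write each shard as its own safetensors file
    for filename, tensor_names in split.filename_to_tensors.items():
        shard = {name: state_dict[name] for name in tensor_names}
        save_file(shard, os.path.join(save_directory, filename), metadata={"format": "pt"})
    # emit an index mapping tensors to shards when the state dict was split
    if split.is_sharded:
        index = {"metadata": split.metadata, "weight_map": split.tensor_to_filename}
        with open(
            os.path.join(save_directory, "model.safetensors.index.json"), "w", encoding="utf-8"
        ) as fout:
            json.dump(index, fout, indent=2)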
@@ -355,7 +362,7 @@ if __name__ == "__main__":
# meta-llama/Llama-3.2-3B has 28 hidden layers
convert_llama_to_rrt(
"meta-llama/Llama-3.2-3B",
"/tmp/rrt_model",
"/tmp/rrt_model", # nosec
recurse_layers=4,
rank=256,
alpha=512,
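For this example invocation, recurse_loops = num_hidden_layers // recurse_layers = 28 // 4 = 7: the converted model keeps 4 physical decoder layers, and each shared linear carries 7 per-loop LoRA A/B pairs (see the modeling code further below).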

View File

@@ -1,2 +1,25 @@
"""
modeling module for the relaxed recursive transformers model
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
from .modeling_rrt_llama import (
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
def register_rrt_model():
pass
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)
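Once register_rrt_model() has run, the Auto* factories resolve the "llama-rrt" model type. A minimal usage sketch (config values are illustrative, not defaults):

from transformers import AutoConfig, AutoModelForCausalLM

from axolotl.integrations.rrt.modeling import register_rrt_model

register_rrt_model()
# "llama-rrt" now maps to RelaxedRecursiveLlamaConfig / RelaxedRecursiveLlamaForCausalLM
config = AutoConfig.for_model("llama-rrt", recurse_layers=4, rank=256, alpha=512)
model = AutoModelForCausalLM.from_config(config)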

View File

@@ -1,3 +1,6 @@
"""
module for the custom configuration of the relaxed recursive transformers model
"""
from transformers import LlamaConfig
@@ -6,8 +9,8 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
Configuration for Relaxed Recursive Llama.
"""
model_type = "llama-rrt"
recurse_layers: int = 4
model_type: str = "llama-rrt"
recurse_layers: int = 4
rank: int
alpha: int
use_dora: bool = True

View File

@@ -1,3 +1,6 @@
"""
module for the shared linear layer for the relaxed recursive transformers model
"""
import math
import torch
@@ -24,7 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
self,
in_features: int,
out_features: int,
B: int,
B: int, # pylint: disable=invalid-name
rank: int,
alpha: int,
fan_in_fan_out: bool = False,
@@ -32,7 +35,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
use_dora: bool = True,
):
super().__init__()
self.B = B
self.B = B # pylint: disable=invalid-name
self.fan_in_fan_out = fan_in_fan_out
self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
@@ -43,10 +46,10 @@ class RelaxedRecursiveDoraLinear(nn.Module):
else:
self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList(
self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
)
self.lora_B_list = nn.ParameterList(
self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
)
# rslora
@@ -75,8 +78,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
w_base = self.weight_base
w_base = w_base.to(x.dtype)
lora_A: torch.Tensor = self.lora_A_list[loop_idx]
lora_B: torch.Tensor = self.lora_B_list[loop_idx]
lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
loop_idx
]
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
loop_idx
]
base_out: torch.Tensor = F.linear(x, w_base, self.bias)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling
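The forward path above combines the shared base projection with the loop-specific LoRA pair. A condensed sketch of that computation (assumptions: rslora-style scaling alpha / sqrt(rank) per the "# rslora" comment, and since the DoRA magnitude step is not visible in this excerpt the sketch simply sums the two paths):

import math

import torch.nn.functional as F

def recursive_lora_forward(x, w_base, bias, lora_A, lora_B, alpha, rank):
    scaling = alpha / math.sqrt(rank)  # rslora scaling (assumed)
    base_out = F.linear(x, w_base.to(x.dtype), bias)
    lora_out = F.linear(F.linear(x, lora_A), lora_B) * scaling
    return base_out + lora_out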

View File

@@ -1,22 +1,33 @@
import logging
from typing import Tuple, Optional, Unpack, Callable, Union
from typing import Callable, Optional, Tuple, Union, Unpack
import torch
from torch import nn
from transformers import LlamaConfig, Cache, DynamicCache
from transformers import Cache, DynamicCache, LlamaConfig
from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__)
# pylint: skip-file
# mypy: ignore-errors
class RelaxedRecursiveLlamaMLP(nn.Module):
def __init__(self, config: RelaxedRecursiveLlamaConfig):
super().__init__()
@@ -24,13 +35,40 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
self.gate_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.up_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.down_proj = RelaxedRecursiveDoraLinear(
self.intermediate_size,
self.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx)
down_proj = self.down_proj(
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
loop_idx,
)
return down_proj
@@ -44,23 +82,51 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config
self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
self.head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout
self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_attention_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
config.num_attention_heads * self.head_dim,
config.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
)
def forward(
@@ -71,32 +137,46 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
loop_idx: int,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
query_states = (
self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
key_states = (
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
if self.config._attn_implementation == "sdpa" and kwargs.get(
"output_attentions", False
):
logger.warning(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface(
self,
@@ -111,8 +191,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights
return attn_output, attn_weights # pylint: disable=return-value
class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
@@ -125,12 +204,24 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
self.self_attn = RelaxedRecursiveLlamaAttention(
config=config, layer_idx=layer_idx
)
self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
self.input_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
self.post_attention_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
def forward(
self,
@@ -142,9 +233,13 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
@@ -186,9 +281,14 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)]
[
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.recurse_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -211,15 +311,25 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
@@ -234,16 +344,24 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
past_key_values = DynamicCache()
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
)
hidden_states = inputs_embeds
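The excerpt ends before the decoder loop. Given that the model builds recurse_layers physical layers and each RelaxedRecursiveDoraLinear holds recurse_loops LoRA pairs, a plausible shape for the loop is the following (a sketch based on the signatures shown above, not the file's actual code):

position_embeddings = self.rotary_emb(hidden_states, position_ids)
num_loops = self.config.num_hidden_layers // self.config.recurse_layers
for loop_idx in range(num_loops):
    # each pass reuses the same physical layers with loop-specific LoRA/norms
    for decoder_layer in self.layers:
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=causal_mask,
            position_ids=position_ids,
            loop_idx=loop_idx,
            past_key_value=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = layer_outputs[0]
hidden_states = self.norm(hidden_states)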