From 791c38dcc36819666800c0cfe8a7392617cbb44e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 24 Jan 2025 13:29:54 -0500
Subject: [PATCH] chore: lint

---
 src/axolotl/integrations/rrt/__init__.py    |  22 --
 src/axolotl/integrations/rrt/args.py        |   5 +-
 src/axolotl/integrations/rrt/cli/convert.py |  13 +-
 .../integrations/rrt/modeling/__init__.py   |  25 ++-
 .../rrt/modeling/configuration_rrt_llama.py |   7 +-
 .../integrations/rrt/modeling/linear.py     |  19 +-
 .../rrt/modeling/modeling_rrt_llama.py      | 198 ++++++++++++++----
 7 files changed, 214 insertions(+), 75 deletions(-)

diff --git a/src/axolotl/integrations/rrt/__init__.py b/src/axolotl/integrations/rrt/__init__.py
index 04372409e..e61939c78 100644
--- a/src/axolotl/integrations/rrt/__init__.py
+++ b/src/axolotl/integrations/rrt/__init__.py
@@ -4,15 +4,8 @@ Axolotl Plugin for Relaxed Recursive Transformers
 
 import logging
 
-from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
-
 from axolotl.integrations.base import BasePlugin
 from axolotl.integrations.rrt.modeling import register_rrt_model
-from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
-    RelaxedRecursiveLlamaConfig,
-    RelaxedRecursiveLlamaForCausalLM,
-    RelaxedRecursiveLlamaModel,
-)
 
 LOG = logging.getLogger(__name__)
@@ -30,18 +23,3 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
             "Registering Relaxed Recursive Transformers modeling with transformers"
         )
         register_rrt_model()
-
-
-def register_rrt_model():
-    """
-    Register Relaxed Recursive Transformers model with transformers
-    """
-
-    # Register configs
-    AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
-
-    # Register models
-    AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
-    AutoModelForCausalLM.register(
-        RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
-    )
diff --git a/src/axolotl/integrations/rrt/args.py b/src/axolotl/integrations/rrt/args.py
index 18cf0d360..2e03995bf 100644
--- a/src/axolotl/integrations/rrt/args.py
+++ b/src/axolotl/integrations/rrt/args.py
@@ -1,3 +1,7 @@
+"""
+Axolotl config args for Relaxed Recursive Transformers plugin
+"""
+
 from pydantic import BaseModel
 
 
@@ -5,4 +9,3 @@ class RelaxedRecursiveTransformerArgs(BaseModel):
     """
     Arguments pertaining to the Relaxed Recursive Transformer model.
     """
-    ...
diff --git a/src/axolotl/integrations/rrt/cli/convert.py b/src/axolotl/integrations/rrt/cli/convert.py
index 5ad7427f5..f6a9542ae 100644
--- a/src/axolotl/integrations/rrt/cli/convert.py
+++ b/src/axolotl/integrations/rrt/cli/convert.py
@@ -1,3 +1,6 @@
+"""
+cli script for converting a pretrained model to a relaxed recursive transformer model
+"""
 import json
 import logging
 import math
@@ -52,7 +55,7 @@ def iter_recursive_parameter_weights(
 ):
     # setup placeholder state_dict for recursive weights, need to keep in float32 precision
     # to avoid precision loss when averaging weights across layers
-    rrt_avg_model_state_dict = {}
+    rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
 
     # iterate over all parameter weights in the model shards
     for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
@@ -93,6 +96,7 @@ def low_rank_decomposition(
     :param max_rank: The maximum rank of the decomposition
     :return: A tuple of tensors (L, R)
    """
+    # pylint: disable=invalid-name
     assert (
         weight.dim() == 2
     ), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
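The body of low_rank_decomposition() is elided from this hunk. As a hedged sketch only — consistent with the signature, the docstring, and the `lora_A, lora_B = low_rank_decomposition(...)` call site below, but not the verbatim implementation (the helper name `svd_low_rank` is hypothetical):

    import torch

    def svd_low_rank(weight: torch.Tensor, max_rank: int):
        # reduced SVD: weight == U @ diag(S) @ Vh, U is (out, k), Vh is (k, in)
        U, S, Vh = torch.linalg.svd(weight.float(), full_matrices=False)
        rank = min(max_rank, S.shape[0])
        A = Vh[:rank, :]            # (rank, in_features), lora_A-shaped
        B = U[:, :rank] * S[:rank]  # (out_features, rank), lora_B-shaped
        # B @ A is the best rank-`rank` approximation of `weight`
        return A, B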
@@ -135,7 +139,9 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True)
     delta_for_svd = final_weight - base_weight
 
     # Low-rank factorization of the delta direction
-    lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank)
+    lora_A, lora_B = low_rank_decomposition(  # pylint: disable=invalid-name
+        delta_for_svd, rank
+    )
 
     if use_dora:
         lora_weight = lora_B @ lora_A
@@ -218,6 +224,7 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
     )
+    # pylint: disable=duplicate-code
     # Save index if sharded
     index = None
     if state_dict_split.is_sharded:
@@ -355,7 +362,7 @@ if __name__ == "__main__":
     # meta-llama/Llama-3.2-3B has 28 hidden layers
     convert_llama_to_rrt(
         "meta-llama/Llama-3.2-3B",
-        "/tmp/rrt_model",
+        "/tmp/rrt_model",  # nosec
         recurse_layers=4,
         rank=256,
         alpha=512,
diff --git a/src/axolotl/integrations/rrt/modeling/__init__.py b/src/axolotl/integrations/rrt/modeling/__init__.py
index b30629bb4..100ad47f0 100644
--- a/src/axolotl/integrations/rrt/modeling/__init__.py
+++ b/src/axolotl/integrations/rrt/modeling/__init__.py
@@ -1,2 +1,25 @@
+"""
+module for modeling relaxed recursive transformers model
+"""
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
+from .modeling_rrt_llama import (
+    RelaxedRecursiveLlamaForCausalLM,
+    RelaxedRecursiveLlamaModel,
+)
+
+
 def register_rrt_model():
-    pass
+    """
+    Register Relaxed Recursive Transformers model with transformers
+    """
+
+    # Register configs
+    AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
+
+    # Register models
+    AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
+    AutoModelForCausalLM.register(
+        RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
+    )
diff --git a/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py
index 88e743088..0344078ec 100644
--- a/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py
+++ b/src/axolotl/integrations/rrt/modeling/configuration_rrt_llama.py
@@ -1,3 +1,6 @@
+"""
+module for custom configuration for relaxed recursive transformers model
+"""
 from transformers import LlamaConfig
 
 
@@ -6,8 +9,8 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
     """
     Configuration for Relaxed Recursive Llama.
     """
-    model_type = "llama-rrt"
-    recurse_layers: int = 4
+    model_type: str = "llama-rrt"
+    recurse_layers: int = 4
     rank: int
     alpha: int
     use_dora: bool = True
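Once register_rrt_model() has run (the plugin above calls it when it registers the modeling code), the "llama-rrt" model_type resolves through the transformers Auto classes. A minimal usage sketch, assuming a checkpoint produced by the convert script shown earlier:

    from transformers import AutoConfig, AutoModelForCausalLM

    from axolotl.integrations.rrt.modeling import register_rrt_model

    register_rrt_model()

    # "/tmp/rrt_model" is the output path used in the convert example above
    config = AutoConfig.from_pretrained("/tmp/rrt_model")
    assert config.model_type == "llama-rrt"
    model = AutoModelForCausalLM.from_pretrained("/tmp/rrt_model")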
""" - model_type = "llama-rrt" - recurse_layers: int = 4 + model_type: str = "llama-rrt" + recurse_layers: int = 4 rank: int alpha: int use_dora: bool = True diff --git a/src/axolotl/integrations/rrt/modeling/linear.py b/src/axolotl/integrations/rrt/modeling/linear.py index f4d4b95bb..a4d1d1de8 100644 --- a/src/axolotl/integrations/rrt/modeling/linear.py +++ b/src/axolotl/integrations/rrt/modeling/linear.py @@ -1,3 +1,6 @@ +""" +module for the shared linear layer for the relaxed recursive transformers model +""" import math import torch @@ -24,7 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module): self, in_features: int, out_features: int, - B: int, + B: int, # pylint: disable=invalid-name rank: int, alpha: int, fan_in_fan_out: bool = False, @@ -32,7 +35,7 @@ class RelaxedRecursiveDoraLinear(nn.Module): use_dora: bool = True, ): super().__init__() - self.B = B + self.B = B # pylint: disable=invalid-name self.fan_in_fan_out = fan_in_fan_out self.weight_base = nn.Parameter(torch.empty(out_features, in_features)) @@ -43,10 +46,10 @@ class RelaxedRecursiveDoraLinear(nn.Module): else: self.register_parameter("bias", None) - self.lora_A_list = nn.ParameterList( + self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name [nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)] ) - self.lora_B_list = nn.ParameterList( + self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name [nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)] ) # rslora @@ -75,8 +78,12 @@ class RelaxedRecursiveDoraLinear(nn.Module): w_base = self.weight_base w_base = w_base.to(x.dtype) - lora_A: torch.Tensor = self.lora_A_list[loop_idx] - lora_B: torch.Tensor = self.lora_B_list[loop_idx] + lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name + loop_idx + ] + lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name + loop_idx + ] base_out: torch.Tensor = F.linear(x, w_base, self.bias) lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py index 444dc923f..cded966e2 100644 --- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py +++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py @@ -1,22 +1,33 @@ import logging -from typing import Tuple, Optional, Unpack, Callable, Union +from typing import Callable, Optional, Tuple, Union, Unpack import torch from torch import nn -from transformers import LlamaConfig, Cache, DynamicCache +from transformers import Cache, DynamicCache, LlamaConfig from transformers.activations import ACT2FN from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \ - LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding +from transformers.models.llama.modeling_llama import ( + LlamaForCausalLM, + LlamaModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + eager_attention_forward, +) from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear + from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig logger = logging.getLogger(__name__) +# pylint: skip-file +# mypy: ignore-errors + + class RelaxedRecursiveLlamaMLP(nn.Module): def __init__(self, config: 
diff --git a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
index 444dc923f..cded966e2 100644
--- a/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
+++ b/src/axolotl/integrations/rrt/modeling/modeling_rrt_llama.py
@@ -1,22 +1,33 @@
 import logging
-from typing import Tuple, Optional, Unpack, Callable, Union
+from typing import Callable, Optional, Tuple, Union, Unpack
 
 import torch
 from torch import nn
-from transformers import LlamaConfig, Cache, DynamicCache
+from transformers import Cache, DynamicCache, LlamaConfig
 from transformers.activations import ACT2FN
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \
-    LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding
+from transformers.models.llama.modeling_llama import (
+    LlamaForCausalLM,
+    LlamaModel,
+    LlamaRMSNorm,
+    LlamaRotaryEmbedding,
+    apply_rotary_pos_emb,
+    eager_attention_forward,
+)
 
 from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
+
 from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
 
 logger = logging.getLogger(__name__)
 
+# pylint: skip-file
+# mypy: ignore-errors
+
+
 class RelaxedRecursiveLlamaMLP(nn.Module):
     def __init__(self, config: RelaxedRecursiveLlamaConfig):
         super().__init__()
@@ -24,13 +35,40 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
         recurse_loops = config.num_hidden_layers // config.recurse_layers
         self.config = config
         self.hidden_size = config.hidden_size
         self.intermediate_size = config.intermediate_size
-        self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
-        self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
-        self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora)
+        self.gate_proj = RelaxedRecursiveDoraLinear(
+            self.hidden_size,
+            self.intermediate_size,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.mlp_bias,
+            use_dora=config.use_dora,
+        )
+        self.up_proj = RelaxedRecursiveDoraLinear(
+            self.hidden_size,
+            self.intermediate_size,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.mlp_bias,
+            use_dora=config.use_dora,
+        )
+        self.down_proj = RelaxedRecursiveDoraLinear(
+            self.intermediate_size,
+            self.hidden_size,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.mlp_bias,
+            use_dora=config.use_dora,
+        )
         self.act_fn = ACT2FN[config.hidden_act]
 
     def forward(self, x, loop_idx: int):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx)
+        down_proj = self.down_proj(
+            self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
+            loop_idx,
+        )
         return down_proj
 
@@ -44,23 +82,51 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
         recurse_loops = config.num_hidden_layers // config.recurse_layers
         self.config = config
         self.layer_idx = layer_idx
-        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
-        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
         self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
         self.is_causal = True
 
         self.q_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
+            config.hidden_size,
+            config.num_attention_heads * self.head_dim,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.attention_bias,
+            use_dora=config.use_dora,
         )
         self.k_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.attention_bias,
+            use_dora=config.use_dora,
        )
         self.v_proj = RelaxedRecursiveDoraLinear(
-            config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
+            config.hidden_size,
+            config.num_key_value_heads * self.head_dim,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.attention_bias,
+            use_dora=config.use_dora,
         )
         self.o_proj = RelaxedRecursiveDoraLinear(
-            config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora
+            config.num_attention_heads * self.head_dim,
+            config.hidden_size,
+            recurse_loops,
+            config.rank,
+            config.alpha,
+            bias=config.attention_bias,
+            use_dora=config.use_dora,
         )
 
     def forward(
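For concreteness, recurse_loops is the number of times the block of shared layers is replayed: with the conversion example earlier in this patch (meta-llama/Llama-3.2-3B, 28 hidden layers, recurse_layers=4), recurse_loops = 28 // 4 = 7, so every shared q/k/v/o and MLP projection carries B = 7 (lora_A, lora_B) adapter pairs while only 4 physical decoder layers are instantiated.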
@@ -71,32 +137,46 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
         self,
         hidden_states: torch.Tensor,
         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         loop_idx: int,
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        **kwargs: Unpack[FlashAttentionKwargs],
+        **kwargs: Unpack[FlashAttentionKwargs],  # pylint: disable=misc
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
-        query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
-        key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
-        value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
+        query_states = (
+            self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
+        )
+        key_states = (
+            self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
+        )
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
 
         attention_interface: Callable = eager_attention_forward
         if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
+            if self.config._attn_implementation == "sdpa" and kwargs.get(
+                "output_attentions", False
+            ):
+                logger.warning(
                     "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                     'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
                 )
             else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+                attention_interface = ALL_ATTENTION_FUNCTIONS[
+                    self.config._attn_implementation
+                ]
 
         attn_output, attn_weights = attention_interface(
             self,
@@ -111,8 +191,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
 
         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output, loop_idx)
-        return attn_output, attn_weights
-
+        return attn_output, attn_weights  # pylint: disable=return-value
 
 
 class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
@@ -125,12 +204,24 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
         recurse_loops = config.num_hidden_layers // config.recurse_layers
 
         self.hidden_size = config.hidden_size
-        self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx)
+        self.self_attn = RelaxedRecursiveLlamaAttention(
+            config=config, layer_idx=layer_idx
+        )
         self.mlp = RelaxedRecursiveLlamaMLP(config)
-        self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
-        self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)])
+        self.input_layernorm_list = nn.ModuleList(
+            [
+                LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+                for _ in range(recurse_loops)
+            ]
+        )
+        self.post_attention_layernorm_list = nn.ModuleList(
+            [
+                LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+                for _ in range(recurse_loops)
+            ]
+        )
 
     def forward(
         self,
@@ -142,9 +233,13 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
-        **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        position_embeddings: Optional[
+            Tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[FlashAttentionKwargs],  # pylint: disable=misc
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
         residual = hidden_states
 
         hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
@@ -186,9 +281,14 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
         self.layers = nn.ModuleList(
-            [RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)]
+            [
+                RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.recurse_layers)
+            ]
         )
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -211,15 +311,25 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
         cache_position: Optional[torch.LongTensor] = None,
         **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )
 
         if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
 
         if self.gradient_checkpointing and self.training and use_cache:
             logger.warning_once(
@@ -234,16 +344,24 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
             past_key_values = DynamicCache()
 
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
             )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
         causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+            attention_mask,
+            inputs_embeds,
+            cache_position,
+            past_key_values,
+            output_attentions,
         )
 
         hidden_states = inputs_embeds
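The excerpt ends just before the decoder loop. Given the layer stack above (config.recurse_layers physical layers, each taking a loop_idx), the recursion presumably replays the shared layers recurse_loops times along these lines — a hedged sketch under those assumptions, not the verbatim forward body:

    recurse_loops = self.config.num_hidden_layers // self.config.recurse_layers
    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    for loop_idx in range(recurse_loops):
        for decoder_layer in self.layers:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                loop_idx=loop_idx,  # selects this loop's adapters and layernorms
                past_key_value=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            hidden_states = layer_outputs[0]

    hidden_states = self.norm(hidden_states)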