chore: lint

This commit is contained in:
Wing Lian
2025-01-24 13:29:54 -05:00
parent 0af78a9882
commit 791c38dcc3
7 changed files with 214 additions and 75 deletions

View File

@@ -4,15 +4,8 @@ Axolotl Plugin for Relaxed Recursive Transformers
import logging import logging
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rrt.modeling import register_rrt_model from axolotl.integrations.rrt.modeling import register_rrt_model
from axolotl.integrations.rrt.modeling.modeling_rrt_llama import (
RelaxedRecursiveLlamaConfig,
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -30,18 +23,3 @@ class RelaxedRecursiveTransformerPlugin(BasePlugin):
"Registering Relaxed Recursive Transformers modeling with transformers" "Registering Relaxed Recursive Transformers modeling with transformers"
) )
register_rrt_model() register_rrt_model()
def register_rrt_model():
"""
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -1,3 +1,7 @@
"""
Axolotl config args for Relaxed Recursive Transformers plugin
"""
from pydantic import BaseModel from pydantic import BaseModel
@@ -5,4 +9,3 @@ class RelaxedRecursiveTransformerArgs(BaseModel):
""" """
Arguments pertaining to the Relaxed Recursive Transformer model. Arguments pertaining to the Relaxed Recursive Transformer model.
""" """
...

View File

@@ -1,3 +1,6 @@
"""
cli script for converting a pretrained model to a relaxed recursive transformer model
"""
import json import json
import logging import logging
import math import math
@@ -52,7 +55,7 @@ def iter_recursive_parameter_weights(
): ):
# setup placeholder state_dict for recursive weights, need to keep in float32 precision # setup placeholder state_dict for recursive weights, need to keep in float32 precision
# to avoid precision loss when averaging weights across layers # to avoid precision loss when averaging weights across layers
rrt_avg_model_state_dict = {} rrt_avg_model_state_dict: dict[str, list[torch.Tensor]] = {}
# iterate over all parameter weights in the model shards # iterate over all parameter weights in the model shards
for key, weight, layer_idx in iter_parameter_weights(model_path, device=device): for key, weight, layer_idx in iter_parameter_weights(model_path, device=device):
@@ -93,6 +96,7 @@ def low_rank_decomposition(
:param max_rank: The maximum rank of the decomposition :param max_rank: The maximum rank of the decomposition
:return: A tuple of tensors (L, R) :return: A tuple of tensors (L, R)
""" """
# pylint: disable=invalid-name
assert ( assert (
weight.dim() == 2 weight.dim() == 2
), f"Only support 2D matrix, but input has {weight.dim()} dimensions." ), f"Only support 2D matrix, but input has {weight.dim()} dimensions."
@@ -135,7 +139,9 @@ def decompose_delta_weight(layer_weight, avg_weight, alpha, rank, use_dora=True)
delta_for_svd = final_weight - base_weight delta_for_svd = final_weight - base_weight
# Low-rank factorization of the delta direction # Low-rank factorization of the delta direction
lora_A, lora_B = low_rank_decomposition(delta_for_svd, rank) lora_A, lora_B = low_rank_decomposition( # pylint: disable=invalid-name
delta_for_svd, rank
)
if use_dora: if use_dora:
lora_weight = lora_B @ lora_A lora_weight = lora_B @ lora_A
@@ -218,6 +224,7 @@ def save_state_dict_to_safetensors(state_dict, save_directory):
state_dict_split = split_torch_state_dict_into_shards( state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size="1GB" state_dict, filename_pattern=filename_pattern, max_shard_size="1GB"
) )
# pylint: disable=duplicate-code
# Save index if sharded # Save index if sharded
index = None index = None
if state_dict_split.is_sharded: if state_dict_split.is_sharded:
@@ -355,7 +362,7 @@ if __name__ == "__main__":
# meta-llama/Llama-3.2-3B has 28 hidden layers # meta-llama/Llama-3.2-3B has 28 hidden layers
convert_llama_to_rrt( convert_llama_to_rrt(
"meta-llama/Llama-3.2-3B", "meta-llama/Llama-3.2-3B",
"/tmp/rrt_model", "/tmp/rrt_model", # nosec
recurse_layers=4, recurse_layers=4,
rank=256, rank=256,
alpha=512, alpha=512,

View File

@@ -1,2 +1,25 @@
"""
module for modeling relaxed recursive transformers model
"""
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
from .modeling_rrt_llama import (
RelaxedRecursiveLlamaForCausalLM,
RelaxedRecursiveLlamaModel,
)
def register_rrt_model(): def register_rrt_model():
pass """
Register Relaxed Recursive Transformers model with transformers
"""
# Register configs
AutoConfig.register("llama-rrt", RelaxedRecursiveLlamaConfig)
# Register models
AutoModel.register(RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaModel)
AutoModelForCausalLM.register(
RelaxedRecursiveLlamaConfig, RelaxedRecursiveLlamaForCausalLM
)

View File

@@ -1,3 +1,6 @@
"""
module for custom configuration for relaxed recursive transformers model
"""
from transformers import LlamaConfig from transformers import LlamaConfig
@@ -6,8 +9,8 @@ class RelaxedRecursiveLlamaConfig(LlamaConfig):
Configuration for Relaxed Recursive Llama. Configuration for Relaxed Recursive Llama.
""" """
model_type = "llama-rrt" model_type: str = "llama-rrt"
recurse_layers: int = 4 recurse_layers: int = 4
rank: int rank: int
alpha: int alpha: int
use_dora: bool = True use_dora: bool = True

View File

@@ -1,3 +1,6 @@
"""
module for the shared linear layer for the relaxed recursive transformers model
"""
import math import math
import torch import torch
@@ -24,7 +27,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
self, self,
in_features: int, in_features: int,
out_features: int, out_features: int,
B: int, B: int, # pylint: disable=invalid-name
rank: int, rank: int,
alpha: int, alpha: int,
fan_in_fan_out: bool = False, fan_in_fan_out: bool = False,
@@ -32,7 +35,7 @@ class RelaxedRecursiveDoraLinear(nn.Module):
use_dora: bool = True, use_dora: bool = True,
): ):
super().__init__() super().__init__()
self.B = B self.B = B # pylint: disable=invalid-name
self.fan_in_fan_out = fan_in_fan_out self.fan_in_fan_out = fan_in_fan_out
self.weight_base = nn.Parameter(torch.empty(out_features, in_features)) self.weight_base = nn.Parameter(torch.empty(out_features, in_features))
@@ -43,10 +46,10 @@ class RelaxedRecursiveDoraLinear(nn.Module):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
self.lora_A_list = nn.ParameterList( self.lora_A_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)] [nn.Parameter(torch.zeros(rank, in_features)) for _ in range(B)]
) )
self.lora_B_list = nn.ParameterList( self.lora_B_list = nn.ParameterList( # pylint: disable=invalid-name
[nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)] [nn.Parameter(torch.zeros(out_features, rank)) for _ in range(B)]
) )
# rslora # rslora
@@ -75,8 +78,12 @@ class RelaxedRecursiveDoraLinear(nn.Module):
w_base = self.weight_base w_base = self.weight_base
w_base = w_base.to(x.dtype) w_base = w_base.to(x.dtype)
lora_A: torch.Tensor = self.lora_A_list[loop_idx] lora_A: torch.Tensor = self.lora_A_list[ # pylint: disable=invalid-name
lora_B: torch.Tensor = self.lora_B_list[loop_idx] loop_idx
]
lora_B: torch.Tensor = self.lora_B_list[ # pylint: disable=invalid-name
loop_idx
]
base_out: torch.Tensor = F.linear(x, w_base, self.bias) base_out: torch.Tensor = F.linear(x, w_base, self.bias)
lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling lora_out: torch.Tensor = F.linear(F.linear(x, lora_A), lora_B) * self.scaling

View File

@@ -1,22 +1,33 @@
import logging import logging
from typing import Tuple, Optional, Unpack, Callable, Union from typing import Callable, Optional, Tuple, Union, Unpack
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig, Cache, DynamicCache from transformers import Cache, DynamicCache, LlamaConfig
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, eager_attention_forward, LlamaRMSNorm, \ from transformers.models.llama.modeling_llama import (
LlamaForCausalLM, LlamaModel, LlamaRotaryEmbedding LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear from axolotl.integrations.rrt.modeling.linear import RelaxedRecursiveDoraLinear
from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig from .configuration_rrt_llama import RelaxedRecursiveLlamaConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# pylint: skip-file
# mypy: ignore-errors
class RelaxedRecursiveLlamaMLP(nn.Module): class RelaxedRecursiveLlamaMLP(nn.Module):
def __init__(self, config: RelaxedRecursiveLlamaConfig): def __init__(self, config: RelaxedRecursiveLlamaConfig):
super().__init__() super().__init__()
@@ -24,13 +35,40 @@ class RelaxedRecursiveLlamaMLP(nn.Module):
self.config = config self.config = config
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size self.intermediate_size = config.intermediate_size
self.gate_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora) self.gate_proj = RelaxedRecursiveDoraLinear(
self.up_proj = RelaxedRecursiveDoraLinear(self.hidden_size, self.intermediate_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora) self.hidden_size,
self.down_proj = RelaxedRecursiveDoraLinear(self.intermediate_size, self.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.mlp_bias, use_dora=config.use_dora) self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.up_proj = RelaxedRecursiveDoraLinear(
self.hidden_size,
self.intermediate_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.down_proj = RelaxedRecursiveDoraLinear(
self.intermediate_size,
self.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.mlp_bias,
use_dora=config.use_dora,
)
self.act_fn = ACT2FN[config.hidden_act] self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x, loop_idx: int): def forward(self, x, loop_idx: int):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx), loop_idx) down_proj = self.down_proj(
self.act_fn(self.gate_proj(x, loop_idx)) * self.up_proj(x, loop_idx),
loop_idx,
)
return down_proj return down_proj
@@ -44,23 +82,51 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers recurse_loops = config.num_hidden_layers // config.recurse_layers
self.config = config self.config = config
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) self.head_dim = getattr(
self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads config, "head_dim", config.hidden_size // config.num_attention_heads
)
self.num_key_value_groups = (
config.num_attention_heads // config.num_key_value_heads
)
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.attention_dropout = config.attention_dropout self.attention_dropout = config.attention_dropout
self.is_causal = True self.is_causal = True
self.q_proj = RelaxedRecursiveDoraLinear( self.q_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_attention_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora config.hidden_size,
config.num_attention_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
) )
self.k_proj = RelaxedRecursiveDoraLinear( self.k_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
) )
self.v_proj = RelaxedRecursiveDoraLinear( self.v_proj = RelaxedRecursiveDoraLinear(
config.hidden_size, config.num_key_value_heads * self.head_dim, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora config.hidden_size,
config.num_key_value_heads * self.head_dim,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
) )
self.o_proj = RelaxedRecursiveDoraLinear( self.o_proj = RelaxedRecursiveDoraLinear(
config.num_attention_heads * self.head_dim, config.hidden_size, recurse_loops, config.rank, config.alpha, bias=config.attention_bias, use_dora=config.use_dora config.num_attention_heads * self.head_dim,
config.hidden_size,
recurse_loops,
config.rank,
config.alpha,
bias=config.attention_bias,
use_dora=config.use_dora,
) )
def forward( def forward(
@@ -71,32 +137,46 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
loop_idx: int, loop_idx: int,
past_key_value: Optional[Cache] = None, past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs], **kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1] input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim) hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) query_states = (
key_states = self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) self.q_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2) )
key_states = (
self.k_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states, loop_idx).view(hidden_shape).transpose(1, 2)
)
cos, sin = position_embeddings cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
if past_key_value is not None: if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache # sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
attention_interface: Callable = eager_attention_forward attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager": if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False): if self.config._attn_implementation == "sdpa" and kwargs.get(
logger.warning_once( "output_attentions", False
):
logger.warning(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
) )
else: else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attention_interface = ALL_ATTENTION_FUNCTIONS[
self.config._attn_implementation
]
attn_output, attn_weights = attention_interface( attn_output, attn_weights = attention_interface(
self, self,
@@ -111,8 +191,7 @@ class RelaxedRecursiveLlamaAttention(nn.Module):
attn_output = attn_output.reshape(*input_shape, -1).contiguous() attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output, loop_idx) attn_output = self.o_proj(attn_output, loop_idx)
return attn_output, attn_weights return attn_output, attn_weights # pylint: disable=return-value
class RelaxedRecursiveLlamaDecoderLayer(nn.Module): class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
@@ -125,12 +204,24 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
recurse_loops = config.num_hidden_layers // config.recurse_layers recurse_loops = config.num_hidden_layers // config.recurse_layers
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.self_attn = RelaxedRecursiveLlamaAttention(config=config, layer_idx=layer_idx) self.self_attn = RelaxedRecursiveLlamaAttention(
config=config, layer_idx=layer_idx
)
self.mlp = RelaxedRecursiveLlamaMLP(config) self.mlp = RelaxedRecursiveLlamaMLP(config)
self.input_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)]) self.input_layernorm_list = nn.ModuleList(
self.post_attention_layernorm_list = nn.ModuleList([LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) for _ in range(recurse_loops)]) [
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
self.post_attention_layernorm_list = nn.ModuleList(
[
LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
for _ in range(recurse_loops)
]
)
def forward( def forward(
self, self,
@@ -142,9 +233,13 @@ class RelaxedRecursiveLlamaDecoderLayer(nn.Module):
output_attentions: Optional[bool] = False, output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False, use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC position_embeddings: Optional[
**kwargs: Unpack[FlashAttentionKwargs], Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ] = None, # necessary, but kept here for BC
**kwargs: Unpack[FlashAttentionKwargs], # pylint: disable=misc
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm_list[loop_idx](hidden_states) hidden_states = self.input_layernorm_list[loop_idx](hidden_states)
@@ -186,9 +281,14 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
self.padding_idx = config.pad_token_id self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList( self.layers = nn.ModuleList(
[RelaxedRecursiveLlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.recurse_layers)] [
RelaxedRecursiveLlamaDecoderLayer(config, layer_idx)
for layer_idx in range(config.recurse_layers)
]
) )
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config) self.rotary_emb = LlamaRotaryEmbedding(config=config)
@@ -211,15 +311,25 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
cache_position: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None,
**flash_attn_kwargs: Unpack[FlashAttentionKwargs], **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPast]: ) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
) )
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if (input_ids is None) ^ (inputs_embeds is not None): if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds") raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
if self.gradient_checkpointing and self.training and use_cache: if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once( logger.warning_once(
@@ -234,16 +344,24 @@ class RelaxedRecursiveLlamaModel(LlamaModel):
past_key_values = DynamicCache() past_key_values = DynamicCache()
if cache_position is None: if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange( cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
device=inputs_embeds.device,
) )
if position_ids is None: if position_ids is None:
position_ids = cache_position.unsqueeze(0) position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask( causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions attention_mask,
inputs_embeds,
cache_position,
past_key_values,
output_attentions,
) )
hidden_states = inputs_embeds hidden_states = inputs_embeds