Clean up landmark patching
This commit is contained in:
@@ -28,15 +28,23 @@ from typing import List, Optional, Tuple, Union
|
|||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import CrossEntropyLoss
|
||||||
from transformers.activations import ACT2FN
|
|
||||||
from transformers.modeling_outputs import (
|
from transformers.modeling_outputs import (
|
||||||
BaseModelOutputWithPast,
|
BaseModelOutputWithPast,
|
||||||
CausalLMOutputWithPast,
|
CausalLMOutputWithPast,
|
||||||
SequenceClassifierOutputWithPast,
|
|
||||||
)
|
)
|
||||||
from transformers.modeling_utils import PreTrainedModel
|
|
||||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
LLAMA_INPUTS_DOCSTRING,
|
||||||
|
LLAMA_START_DOCSTRING,
|
||||||
|
LlamaMLP,
|
||||||
|
LlamaPreTrainedModel,
|
||||||
|
LlamaRMSNorm,
|
||||||
|
LlamaRotaryEmbedding,
|
||||||
|
_expand_mask,
|
||||||
|
_make_causal_mask,
|
||||||
|
rotate_half,
|
||||||
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import (
|
||||||
add_start_docstrings,
|
add_start_docstrings,
|
||||||
add_start_docstrings_to_model_forward,
|
add_start_docstrings_to_model_forward,
|
||||||
@@ -51,131 +59,6 @@ _CONFIG_FOR_DOC = "LlamaConfig"
|
|||||||
MEM_TOKEN = "<landmark>" # nosec
|
MEM_TOKEN = "<landmark>" # nosec
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
|
||||||
def _make_causal_mask(
|
|
||||||
input_ids_shape: torch.Size,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
past_key_values_length: int = 0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Make causal mask used for bi-directional self-attention.
|
|
||||||
"""
|
|
||||||
bsz, tgt_len = input_ids_shape
|
|
||||||
mask = torch.full(
|
|
||||||
(tgt_len, tgt_len),
|
|
||||||
torch.tensor(torch.finfo(dtype).min, device=device),
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
mask_cond = torch.arange(mask.size(-1), device=device)
|
|
||||||
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
|
||||||
mask = mask.to(dtype)
|
|
||||||
|
|
||||||
if past_key_values_length > 0:
|
|
||||||
mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
mask,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
return mask[None, None, :, :].expand(
|
|
||||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart._expand_mask
|
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
||||||
"""
|
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
||||||
"""
|
|
||||||
bsz, src_len = mask.size()
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
||||||
|
|
||||||
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
|
||||||
|
|
||||||
inverted_mask = 1.0 - expanded_mask
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRMSNorm(nn.Module):
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
|
||||||
"""
|
|
||||||
LlamaRMSNorm is equivalent to T5LayerNorm
|
|
||||||
"""
|
|
||||||
super().__init__()
|
|
||||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
|
||||||
self.variance_epsilon = eps
|
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
|
||||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
|
||||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
|
||||||
|
|
||||||
# convert into half-precision if necessary
|
|
||||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
|
||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
|
||||||
|
|
||||||
return self.weight * hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class LlamaRotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
|
||||||
super().__init__()
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
|
||||||
self.register_buffer("inv_freq", inv_freq)
|
|
||||||
|
|
||||||
# Build here to make `torch.jit.trace` work.
|
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
|
||||||
t = torch.arange(
|
|
||||||
self.max_seq_len_cached,
|
|
||||||
device=self.inv_freq.device,
|
|
||||||
dtype=self.inv_freq.dtype,
|
|
||||||
)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
self.register_buffer(
|
|
||||||
"cos_cached", emb.cos()[None, None, :, :], persistent=False
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sin_cached", emb.sin()[None, None, :, :], persistent=False
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
|
||||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(
|
|
||||||
self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
|
|
||||||
)
|
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
|
||||||
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
|
|
||||||
self.register_buffer(
|
|
||||||
"cos_cached", emb.cos()[None, None, :, :], persistent=False
|
|
||||||
)
|
|
||||||
self.register_buffer(
|
|
||||||
"sin_cached", emb.sin()[None, None, :, :], persistent=False
|
|
||||||
)
|
|
||||||
return (
|
|
||||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
|
||||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
"""Rotates half the hidden dims of the input."""
|
|
||||||
x1 = x[..., : x.shape[-1] // 2]
|
|
||||||
x2 = x[..., x.shape[-1] // 2 :]
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||||
@@ -190,24 +73,11 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
class LlamaMLP(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
intermediate_size: int,
|
|
||||||
hidden_act: str,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
||||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
|
||||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
|
||||||
self.act_fn = ACT2FN[hidden_act]
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
|
||||||
|
|
||||||
|
|
||||||
class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
|
class LandmarkGroupedSoftmaxFunction(torch.autograd.Function):
|
||||||
|
"""
|
||||||
|
Landmark grouped softmax function.
|
||||||
|
"""
|
||||||
|
|
||||||
# Note that forward, setup_context, and backward are @staticmethods
|
# Note that forward, setup_context, and backward are @staticmethods
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
|
def forward(ctx, x, dim, mem_cnt, resp_mem_idx):
|
||||||
@@ -682,16 +552,14 @@ class LlamaAttention(nn.Module):
|
|||||||
# upcast attention to fp32
|
# upcast attention to fp32
|
||||||
if is_mem is None:
|
if is_mem is None:
|
||||||
raise ValueError("Don't use this without landmarks")
|
raise ValueError("Don't use this without landmarks")
|
||||||
# attn_weights = nn.functional.softmax(
|
|
||||||
# attn_weights, dim=-1, dtype=torch.float32
|
attn_weights = landmark_grouped_softmax(
|
||||||
# ).to(query_states.dtype)
|
attn_weights,
|
||||||
else:
|
dim=-1,
|
||||||
attn_weights = landmark_grouped_softmax(
|
is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
|
||||||
attn_weights,
|
last_section_mask=last_section_mask,
|
||||||
dim=-1,
|
).to(query_states.dtype)
|
||||||
is_mem=is_mem.expand(-1, self.num_heads, -1, -1),
|
|
||||||
last_section_mask=last_section_mask,
|
|
||||||
).to(query_states.dtype)
|
|
||||||
if attn_prefix is not None:
|
if attn_prefix is not None:
|
||||||
attn_prefix, attn_weights = torch.split(
|
attn_prefix, attn_weights = torch.split(
|
||||||
attn_weights,
|
attn_weights,
|
||||||
@@ -722,6 +590,10 @@ class LlamaAttention(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaDecoderLayer(nn.Module):
|
class LlamaDecoderLayer(nn.Module):
|
||||||
|
"""
|
||||||
|
Llama Decoder layer
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config: LlamaConfig):
|
def __init__(self, config: LlamaConfig):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
@@ -802,114 +674,6 @@ class LlamaDecoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
LLAMA_START_DOCSTRING = r"""
|
|
||||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
|
||||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
|
||||||
etc.)
|
|
||||||
|
|
||||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
|
||||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
|
||||||
and behavior.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
config ([`LlamaConfig`]):
|
|
||||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
|
||||||
load the weights associated with the model, only the configuration. Check out the
|
|
||||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
|
||||||
LLAMA_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class LlamaPreTrainedModel(PreTrainedModel):
|
|
||||||
config_class = LlamaConfig
|
|
||||||
base_model_prefix = "model"
|
|
||||||
supports_gradient_checkpointing = True
|
|
||||||
_no_split_modules = ["LlamaDecoderLayer"]
|
|
||||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
|
||||||
std = self.config.initializer_range
|
|
||||||
if isinstance(module, nn.Linear):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.bias is not None:
|
|
||||||
module.bias.data.zero_()
|
|
||||||
elif isinstance(module, nn.Embedding):
|
|
||||||
module.weight.data.normal_(mean=0.0, std=std)
|
|
||||||
if module.padding_idx is not None:
|
|
||||||
module.weight.data[module.padding_idx].zero_()
|
|
||||||
|
|
||||||
def _set_gradient_checkpointing(self, module, value=False):
|
|
||||||
if isinstance(module, LlamaModel):
|
|
||||||
module.gradient_checkpointing = value
|
|
||||||
|
|
||||||
|
|
||||||
LLAMA_INPUTS_DOCSTRING = r"""
|
|
||||||
Args:
|
|
||||||
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
|
||||||
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
|
||||||
it.
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
|
||||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
|
||||||
|
|
||||||
- 1 for tokens that are **not masked**,
|
|
||||||
- 0 for tokens that are **masked**.
|
|
||||||
|
|
||||||
[What are attention masks?](../glossary#attention-mask)
|
|
||||||
|
|
||||||
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
|
||||||
[`PreTrainedTokenizer.__call__`] for details.
|
|
||||||
|
|
||||||
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
|
|
||||||
`past_key_values`).
|
|
||||||
|
|
||||||
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
|
|
||||||
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
|
|
||||||
information on the default strategy.
|
|
||||||
|
|
||||||
- 1 indicates the head is **not masked**,
|
|
||||||
- 0 indicates the head is **masked**.
|
|
||||||
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
|
||||||
config.n_positions - 1]`.
|
|
||||||
|
|
||||||
[What are position IDs?](../glossary#position-ids)
|
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
|
||||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
|
||||||
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
|
||||||
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
|
||||||
model's internal embedding lookup matrix.
|
|
||||||
use_cache (`bool`, *optional*):
|
|
||||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
|
||||||
`past_key_values`).
|
|
||||||
output_attentions (`bool`, *optional*):
|
|
||||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
|
||||||
tensors for more detail.
|
|
||||||
output_hidden_states (`bool`, *optional*):
|
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
|
||||||
more detail.
|
|
||||||
return_dict (`bool`, *optional*):
|
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
"The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
|
||||||
LLAMA_START_DOCSTRING,
|
LLAMA_START_DOCSTRING,
|
||||||
@@ -1178,6 +942,10 @@ class LlamaModel(LlamaPreTrainedModel):
|
|||||||
|
|
||||||
|
|
||||||
class LlamaForCausalLM(LlamaPreTrainedModel):
|
class LlamaForCausalLM(LlamaPreTrainedModel):
|
||||||
|
"""
|
||||||
|
Llama model with a causal language modeling head.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.model = LlamaModel(config)
|
self.model = LlamaModel(config)
|
||||||
@@ -1448,149 +1216,15 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
|
|||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
|
||||||
"""
|
|
||||||
The LLaMa Model transformer with a sequence classification head on top (linear layer).
|
|
||||||
|
|
||||||
[`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
|
||||||
(e.g. GPT-2) do.
|
|
||||||
|
|
||||||
Since it does classification on the last token, it requires to know the position of the last token. If a
|
|
||||||
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
|
||||||
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
|
||||||
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
|
||||||
each row of the batch).
|
|
||||||
""",
|
|
||||||
LLAMA_START_DOCSTRING,
|
|
||||||
)
|
|
||||||
class LlamaForSequenceClassification(LlamaPreTrainedModel):
|
|
||||||
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__(config)
|
|
||||||
self.num_labels = config.num_labels
|
|
||||||
self.model = LlamaModel(config)
|
|
||||||
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
|
||||||
|
|
||||||
# Initialize weights and apply final processing
|
|
||||||
self.post_init()
|
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
|
||||||
return self.model.embed_tokens
|
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
|
||||||
self.model.embed_tokens = value
|
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: torch.LongTensor = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
|
||||||
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
|
||||||
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
|
||||||
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
|
||||||
"""
|
|
||||||
return_dict = (
|
|
||||||
return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
transformer_outputs = self.model(
|
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
hidden_states = transformer_outputs[0]
|
|
||||||
logits = self.score(hidden_states)
|
|
||||||
|
|
||||||
if input_ids is not None:
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
else:
|
|
||||||
batch_size = inputs_embeds.shape[0]
|
|
||||||
|
|
||||||
if self.config.pad_token_id is None and batch_size != 1:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot handle batch sizes > 1 if no padding token is defined."
|
|
||||||
)
|
|
||||||
if self.config.pad_token_id is None:
|
|
||||||
sequence_lengths = -1
|
|
||||||
else:
|
|
||||||
if input_ids is not None:
|
|
||||||
sequence_lengths = (
|
|
||||||
torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
|
|
||||||
).to(logits.device)
|
|
||||||
else:
|
|
||||||
sequence_lengths = -1
|
|
||||||
|
|
||||||
pooled_logits = logits[
|
|
||||||
torch.arange(batch_size, device=logits.device), sequence_lengths
|
|
||||||
]
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
labels = labels.to(logits.device)
|
|
||||||
if self.config.problem_type is None:
|
|
||||||
if self.num_labels == 1:
|
|
||||||
self.config.problem_type = "regression"
|
|
||||||
elif self.num_labels > 1 and (
|
|
||||||
labels.dtype == torch.long or labels.dtype == torch.int
|
|
||||||
):
|
|
||||||
self.config.problem_type = "single_label_classification"
|
|
||||||
else:
|
|
||||||
self.config.problem_type = "multi_label_classification"
|
|
||||||
|
|
||||||
if self.config.problem_type == "regression":
|
|
||||||
loss_fct = MSELoss()
|
|
||||||
if self.num_labels == 1:
|
|
||||||
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
|
||||||
else:
|
|
||||||
loss = loss_fct(pooled_logits, labels)
|
|
||||||
elif self.config.problem_type == "single_label_classification":
|
|
||||||
loss_fct = CrossEntropyLoss()
|
|
||||||
loss = loss_fct(
|
|
||||||
pooled_logits.view(-1, self.num_labels), labels.view(-1)
|
|
||||||
)
|
|
||||||
elif self.config.problem_type == "multi_label_classification":
|
|
||||||
loss_fct = BCEWithLogitsLoss()
|
|
||||||
loss = loss_fct(pooled_logits, labels)
|
|
||||||
if not return_dict:
|
|
||||||
output = (pooled_logits,) + transformer_outputs[1:]
|
|
||||||
return ((loss,) + output) if loss is not None else output
|
|
||||||
|
|
||||||
return SequenceClassifierOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=pooled_logits,
|
|
||||||
past_key_values=transformer_outputs.past_key_values,
|
|
||||||
hidden_states=transformer_outputs.hidden_states,
|
|
||||||
attentions=transformer_outputs.attentions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def add_mem_tokens(example, mem_freq, mem_id):
|
def add_mem_tokens(example, mem_freq, mem_id):
|
||||||
x = example["input_ids"]
|
ids = example["input_ids"]
|
||||||
ret = []
|
ret = []
|
||||||
prev_idx = 0
|
prev_idx = 0
|
||||||
for t_idx in range(mem_freq, len(x), mem_freq):
|
for t_idx in range(mem_freq, len(ids), mem_freq):
|
||||||
ret.extend(x[prev_idx:t_idx])
|
ret.extend(ids[prev_idx:t_idx])
|
||||||
ret.append(mem_id)
|
ret.append(mem_id)
|
||||||
prev_idx = t_idx
|
prev_idx = t_idx
|
||||||
ret.extend(x[prev_idx:])
|
ret.extend(ids[prev_idx:])
|
||||||
# drop attention_mask
|
# drop attention_mask
|
||||||
return {"input_ids": ret}
|
return {"input_ids": ret}
|
||||||
|
|
||||||
@@ -1602,3 +1236,4 @@ def patch_llama_with_landmark_attn():
|
|||||||
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
|
transformers.models.llama.modeling_llama.LlamaModel = LlamaModel
|
||||||
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
|
||||||
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
|
||||||
|
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
||||||
|
|||||||
Reference in New Issue
Block a user