"""Llama CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code

from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def cce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
r"""
|
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
|
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
|
|
|
Returns:
|
|
|
|
Example:
|
|
|
|
```python
|
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
|
|
|
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
|
|
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
|
>>> # Generate
|
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
|
```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute the necessary logits, and do not upcast them to float if we are not computing the loss.
    # For an int k, slice(-k, None) keeps the last k positions; k == 0 gives slice(0, None), i.e. all
    # positions. A 1D tensor selects explicit sequence indices instead.
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        # Cut cross-entropy path: fuse the lm_head projection into the loss,
        # so the full (batch, seq, vocab) logits tensor is never materialized
        # and `logits` stays None in the returned output.
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
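

# Note (illustrative sketch, not part of the upstream patch): conceptually,
# apply_lce computes the same cross-entropy loss the standard path would, but
# with the lm_head projection fused into the loss kernel so the full
# (batch, seq, vocab) logits tensor is never allocated. Ignoring CCE's
# blockwise kernel, label shifting, and PatchOptions behavior, a naive
# equivalent of the math would be:
#
#     def naive_lce(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
#         logits = hidden @ weight.T  # the allocation CCE avoids
#         return torch.nn.functional.cross_entropy(
#             logits.float().view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
#         )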


def patch_llama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch Llama for CCE."""
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama import modeling_llama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama.LlamaForCausalLM
        ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
        # Patch just this instance by binding cce_forward as its method.
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    # Given a config or model name instead, patch the class itself so every
    # LlamaForCausalLM constructed afterwards uses cce_forward.
    modeling_llama.LlamaForCausalLM.forward = cce_forward
    return None
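

# Usage sketch (illustrative; the `PatchOptions()` arguments are an assumption
# here, see cut_cross_entropy.transformers.utils for the real signature):
#
#     import transformers
#
#     model = transformers.AutoModelForCausalLM.from_pretrained(
#         "meta-llama/Llama-2-7b-hf"
#     )
#     model = patch_llama(model, PatchOptions())  # patches this instance only
#
#     # Or patch the class before instantiating any model:
#     config = transformers.AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
#     patch_llama(config, PatchOptions())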
|