"""Llama CCE patch. Adapted from transformers v4.51.2""" # pylint: disable=duplicate-code from types import MethodType from typing import Optional, Union import torch import transformers from cut_cross_entropy.transformers.utils import ( PatchOptions, TransformersModelT, apply_lce, ) from transformers.cache_utils import Cache from transformers.modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, ) from transformers.models.llama.modeling_llama import ( KwargsForCausalLM, ) from transformers.processing_utils import Unpack from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") def cce_forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[KwargsForCausalLM], ) -> CausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. logits_to_keep (`int` or `torch.Tensor`, *optional*): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. This is useful when using packed tensor format (single dimension for batch and sequence length). Returns: Example: ```python >>> from transformers import AutoTokenizer, LlamaForCausalLM >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") >>> prompt = "Hey, are you conscious? Can you talk to me?" >>> inputs = tokenizer(prompt, return_tensors="pt") >>> # Generate >>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." ```""" output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs: BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, cache_position=cache_position, **kwargs, ) hidden_states = outputs.last_hidden_state if hidden_states is None: raise ValueError("hidden_states is None") loss = None logits = None # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = ( slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep ) if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): assert labels is not None loss = apply_lce( hidden_states[:, slice_indices, :], self.lm_head.weight, labels, _PATCH_OPTS, **kwargs, ) else: logits = self.lm_head(hidden_states[:, slice_indices, :]) if labels is not None: loss = self.loss_function( logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs, ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) def patch_llama( maybe_model: TransformersModelT | str | transformers.PretrainedConfig, patch_options: PatchOptions, ) -> TransformersModelT | None: """Patch Llama for CCE.""" global _PATCH_OPTS # pylint: disable=global-statement from transformers.models.llama import modeling_llama _PATCH_OPTS = patch_options if isinstance(maybe_model, transformers.PreTrainedModel): assert isinstance( maybe_model, modeling_llama.LlamaForCausalLM ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." maybe_model.forward = MethodType(cce_forward, maybe_model) return maybe_model modeling_llama.LlamaForCausalLM.forward = cce_forward return None