feat: add CCE for gemma3, cohere, and cohere2 (#2443)

* feat: add CCE for gemma3 and cohere1/2

* fix: change from relative import to absolute

* feat: add multipack for cohere&cohere2

* chore: improve comments

* fix: add gemma3_text

* feat: add cohere2 example

* fix: cohere forward

* fix: patch for cohere2

* feat: add command r v01 qlora sample

* chore: lint

* feat: upgrade gemma3 and gemma2 patch to use logits_to_keep

* chore: lint

* fix: add deprecate_kwarg decorator

* fix: add cce for gemma3 conditionalgeneration

* fix: gemma3 patch to defer logits calculation

* fix: patch gemma3 if given as model

* fix: remove not working config

* fix: update comments to clarify changes

* feat(doc): add supported models to readme

* fix: address difference in our cohere patch

* feat: add mistral3

* feat: add gemma

* feat(doc): update README to include gemma and mistral3 in supported models

* fix: gemma patch

* fix: import

* fix: gemma patch to be standalone

* fix: gemma3 warn about not support final_logit_softcapping

* feat: add mllama CCE

* chore: add abbireviation to doc

* fix: remove unneeded gemma3 eager warning

* fix: save processor if available

* fix: enable save processor on merge

* fix: wrong env meaning
This commit is contained in:
NanoCode012
2025-03-27 05:13:51 +07:00
committed by GitHub
parent a9b0733f2c
commit 2c34a4634e
16 changed files with 1826 additions and 15 deletions

View File

@@ -0,0 +1,175 @@
"""Gemma CCE patch"""
# This patch is based off transformers 4.50.0.
# pylint: disable=duplicate-code
from types import MethodType
from typing import Optional, Tuple, Union
import torch
import transformers
from cut_cross_entropy.transformers.utils import (
PatchOptions,
TransformersModelT,
apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.gemma.modeling_gemma import (
_CONFIG_FOR_DOC,
GEMMA_INPUTS_DOCSTRING,
KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
add_start_docstrings_to_model_forward,
replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
_PATCH_OPTS: PatchOptions | None = None
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
self,
input_ids: torch.LongTensor | None = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, GemmaForCausalLM
>>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b")
>>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
>>> prompt = "What is your favorite condiment?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"What is your favorite condiment?"
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
loss = None
logits = None
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = (
slice(-logits_to_keep, None)
if isinstance(logits_to_keep, int)
else logits_to_keep
)
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
assert labels is not None
loss = apply_lce(
hidden_states[:, slice_indices, :],
self.lm_head.weight,
labels,
_PATCH_OPTS,
**kwargs,
)
else:
logits = self.lm_head(hidden_states[:, slice_indices, :])
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def patch_gemma(
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
patch_options: PatchOptions,
) -> TransformersModelT | None:
global _PATCH_OPTS # pylint: disable=global-statement
from transformers.models.gemma import modeling_gemma
_PATCH_OPTS = patch_options
if isinstance(maybe_model, transformers.PreTrainedModel):
assert isinstance(
maybe_model, modeling_gemma.GemmaForCausalLM
), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}."
maybe_model.forward = MethodType(cce_forward, maybe_model)
return maybe_model
modeling_gemma.GemmaForCausalLM.forward = cce_forward
return None