Feat: Add qwen3 and CCE for qwen family (#2518)
This commit is contained in:
68
examples/qwen3/qlora-fsdp.yaml
Normal file
68
examples/qwen3/qlora-fsdp.yaml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
base_model: Qwen/Qwen3-8B
|
||||||
|
# Automatically upload checkpoint and final model to HF
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./outputs/out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: true
|
||||||
|
eval_sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 64
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_torch_fused
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
bf16: auto
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
gradient_checkpointing_kwargs:
|
||||||
|
use_reentrant: false
|
||||||
|
resume_from_checkpoint:
|
||||||
|
logging_steps: 1
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
saves_per_epoch: 1
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
- full_shard
|
||||||
|
- auto_wrap
|
||||||
|
fsdp_config:
|
||||||
|
fsdp_limit_all_gathers: true
|
||||||
|
fsdp_sync_module_states: true
|
||||||
|
fsdp_offload_params: true
|
||||||
|
fsdp_use_orig_params: false
|
||||||
|
fsdp_cpu_ram_efficient_loading: true
|
||||||
|
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||||
|
fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
|
||||||
|
fsdp_state_dict_type: FULL_STATE_DICT
|
||||||
|
fsdp_sharding_strategy: FULL_SHARD
|
||||||
|
special_tokens:
|
||||||
@@ -32,8 +32,8 @@ plugins:
|
|||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- llama
|
- llama
|
||||||
- llama4_text
|
|
||||||
- llama4
|
- llama4
|
||||||
|
- llama4_text
|
||||||
- mllama
|
- mllama
|
||||||
- phi3
|
- phi3
|
||||||
- gemma
|
- gemma
|
||||||
@@ -43,6 +43,11 @@ plugins:
|
|||||||
- mistral
|
- mistral
|
||||||
- mistral3
|
- mistral3
|
||||||
- qwen2
|
- qwen2
|
||||||
|
- qwen2_moe
|
||||||
|
- qwen2_vl
|
||||||
|
- qwen2_5_vl
|
||||||
|
- qwen3
|
||||||
|
- qwen3_moe
|
||||||
- cohere
|
- cohere
|
||||||
- cohere2
|
- cohere2
|
||||||
- glm
|
- glm
|
||||||
|
|||||||
174
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
174
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Llama CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.cache_utils import Cache
|
||||||
|
from transformers.modeling_outputs import (
|
||||||
|
BaseModelOutputWithPast,
|
||||||
|
CausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.models.llama.modeling_llama import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
LLAMA_INPUTS_DOCSTRING,
|
||||||
|
KwargsForCausalLM,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[Cache] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> CausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||||
|
|
||||||
|
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: BaseModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(
|
||||||
|
logits=logits,
|
||||||
|
labels=labels,
|
||||||
|
vocab_size=self.config.vocab_size,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return CausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_llama(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
"""Patch Llama for CCE."""
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
from transformers.models.llama import modeling_llama
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_llama.LlamaForCausalLM
|
||||||
|
), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_llama.LlamaForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -5,9 +5,7 @@
|
|||||||
import transformers
|
import transformers
|
||||||
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
|
||||||
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
|
||||||
from cut_cross_entropy.transformers.llama import patch_llama
|
|
||||||
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
from cut_cross_entropy.transformers.phi3 import patch_phi3
|
||||||
from cut_cross_entropy.transformers.qwen2 import patch_qwen2
|
|
||||||
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT
|
||||||
|
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
|
||||||
@@ -24,6 +22,9 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import (
|
|||||||
patch_glm,
|
patch_glm,
|
||||||
patch_glm4,
|
patch_glm4,
|
||||||
)
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||||
|
patch_llama,
|
||||||
|
)
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
|
||||||
patch_llama4,
|
patch_llama4,
|
||||||
patch_llama4_text,
|
patch_llama4_text,
|
||||||
@@ -33,6 +34,22 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import (
|
|||||||
patch_mistral3,
|
patch_mistral3,
|
||||||
)
|
)
|
||||||
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
|
||||||
|
patch_qwen2,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
|
||||||
|
patch_qwen2_5_vl,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
|
||||||
|
patch_qwen2_moe,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
|
||||||
|
patch_qwen2_vl,
|
||||||
|
)
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
|
||||||
|
patch_qwen3_moe,
|
||||||
|
)
|
||||||
|
|
||||||
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
||||||
"llama": patch_llama,
|
"llama": patch_llama,
|
||||||
@@ -47,6 +64,11 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = {
|
|||||||
"mistral": patch_mistral,
|
"mistral": patch_mistral,
|
||||||
"mistral3": patch_mistral3,
|
"mistral3": patch_mistral3,
|
||||||
"qwen2": patch_qwen2,
|
"qwen2": patch_qwen2,
|
||||||
|
"qwen2_moe": patch_qwen2_moe,
|
||||||
|
"qwen2_vl": patch_qwen2_vl,
|
||||||
|
"qwen2_5_vl": patch_qwen2_5_vl,
|
||||||
|
"qwen3": patch_qwen3,
|
||||||
|
"qwen3_moe": patch_qwen3_moe,
|
||||||
"cohere": patch_cohere,
|
"cohere": patch_cohere,
|
||||||
"cohere2": patch_cohere2,
|
"cohere2": patch_cohere2,
|
||||||
"glm": patch_glm,
|
"glm": patch_glm,
|
||||||
|
|||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
from transformers.models.qwen2 import modeling_qwen2
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
|
||||||
|
cce_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2.Qwen2ForCausalLM
|
||||||
|
), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,246 @@
|
|||||||
|
"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
||||||
|
Qwen2_5_VLCausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
second_per_grid_ts: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.type(self.visual.dtype)
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||||
|
n_image_features = image_embeds.shape[0]
|
||||||
|
if n_image_tokens != n_image_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = input_ids == self.config.image_token_id
|
||||||
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
|
image_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||||
|
n_video_features = video_embeds.shape[0]
|
||||||
|
if n_video_tokens != n_video_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||||
|
)
|
||||||
|
|
||||||
|
mask = input_ids == self.config.video_token_id
|
||||||
|
mask_unsqueezed = mask.unsqueeze(-1)
|
||||||
|
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
|
||||||
|
video_mask = mask_expanded.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||||
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
|
# calculate RoPE index once per generation in the pre-fill stage only
|
||||||
|
if (
|
||||||
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
|
or self.rope_deltas is None
|
||||||
|
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||||
|
):
|
||||||
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
|
input_ids,
|
||||||
|
image_grid_thw,
|
||||||
|
video_grid_thw,
|
||||||
|
second_per_grid_ts,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
self.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
else:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
delta = (
|
||||||
|
(cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
|
||||||
|
if cache_position is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||||
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||||
|
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||||
|
position_ids = position_ids.add(delta) # type: ignore
|
||||||
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Qwen2_5_VLCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
rope_deltas=self.rope_deltas,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_5_vl(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration
|
||||||
|
), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = (
|
||||||
|
cce_forward_multimodal
|
||||||
|
)
|
||||||
|
return None
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.models.qwen2_moe.modeling_qwen2_moe import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN2MOE_INPUTS_DOCSTRING,
|
||||||
|
MoeCausalLMOutputWithPast,
|
||||||
|
MoeModelOutputWithPast,
|
||||||
|
load_balancing_loss_func,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**loss_kwargs,
|
||||||
|
) -> MoeCausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: MoeModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**loss_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
|
||||||
|
|
||||||
|
aux_loss = None
|
||||||
|
if output_router_logits:
|
||||||
|
aux_loss = load_balancing_loss_func(
|
||||||
|
outputs.router_logits,
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
if labels is not None:
|
||||||
|
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||||
|
loss.device # type: ignore
|
||||||
|
) # make sure to reside in the same device
|
||||||
|
|
||||||
|
return MoeCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
aux_loss=aux_loss, # type: ignore
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
router_logits=outputs.router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_moe(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_moe import modeling_qwen2_moe
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
|
||||||
|
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(forward, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,249 @@
|
|||||||
|
"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN2_VL_INPUTS_DOCSTRING,
|
||||||
|
Qwen2VLCausalLMOutputWithPast,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def cce_forward_multimodal(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
pixel_values: Optional[torch.Tensor] = None,
|
||||||
|
pixel_values_videos: Optional[torch.FloatTensor] = None,
|
||||||
|
image_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
video_grid_thw: Optional[torch.LongTensor] = None,
|
||||||
|
rope_deltas: Optional[torch.LongTensor] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
|
||||||
|
|
||||||
|
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
|
||||||
|
|
||||||
|
>>> messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"type": "image"},
|
||||||
|
{"type": "text", "text": "What is shown in this image?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.model.embed_tokens(input_ids)
|
||||||
|
if pixel_values is not None:
|
||||||
|
pixel_values = pixel_values.type(self.visual.get_dtype())
|
||||||
|
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
|
||||||
|
n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
|
||||||
|
n_image_features = image_embeds.shape[0]
|
||||||
|
if n_image_tokens != n_image_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||||
|
)
|
||||||
|
image_mask = (
|
||||||
|
(input_ids == self.config.image_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore
|
||||||
|
|
||||||
|
if pixel_values_videos is not None:
|
||||||
|
pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
|
||||||
|
video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
|
||||||
|
n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
|
||||||
|
n_video_features = video_embeds.shape[0]
|
||||||
|
if n_video_tokens != n_video_features:
|
||||||
|
raise ValueError(
|
||||||
|
f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
|
||||||
|
)
|
||||||
|
video_mask = (
|
||||||
|
(input_ids == self.config.video_token_id)
|
||||||
|
.unsqueeze(-1)
|
||||||
|
.expand_as(inputs_embeds)
|
||||||
|
.to(inputs_embeds.device)
|
||||||
|
)
|
||||||
|
video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||||
|
inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask.to(inputs_embeds.device)
|
||||||
|
|
||||||
|
# if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
|
||||||
|
if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
|
||||||
|
# calculate RoPE index once per generation in the pre-fill stage only
|
||||||
|
if (
|
||||||
|
(cache_position is not None and cache_position[0] == 0)
|
||||||
|
or self.rope_deltas is None
|
||||||
|
or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore
|
||||||
|
):
|
||||||
|
position_ids, rope_deltas = self.get_rope_index(
|
||||||
|
input_ids, image_grid_thw, video_grid_thw, attention_mask
|
||||||
|
)
|
||||||
|
self.rope_deltas = rope_deltas
|
||||||
|
# then use the prev pre-calculated rope-deltas to get the correct position ids
|
||||||
|
else:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
delta = (
|
||||||
|
cache_position[0] + self.rope_deltas
|
||||||
|
if cache_position is not None
|
||||||
|
else 0
|
||||||
|
)
|
||||||
|
position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore
|
||||||
|
position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore
|
||||||
|
if cache_position is not None: # otherwise `deltas` is an int `0`
|
||||||
|
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore
|
||||||
|
delta = delta.to(position_ids.device) # type: ignore
|
||||||
|
position_ids = position_ids.add(delta) # type: ignore
|
||||||
|
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore
|
||||||
|
|
||||||
|
outputs = self.model(
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=position_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs[0]
|
||||||
|
logits = None
|
||||||
|
loss = None
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states,
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||||
|
logits = logits.float()
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
output = (logits,) + outputs[1:]
|
||||||
|
return (loss,) + output if loss is not None else output
|
||||||
|
|
||||||
|
return Qwen2VLCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
rope_deltas=self.rope_deltas,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen2_vl(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen2_vl import modeling_qwen2_vl
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration
|
||||||
|
), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal
|
||||||
|
return None
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
from transformers.models.qwen3 import modeling_qwen3
|
||||||
|
|
||||||
|
# Set the _PATCH_OPTS in the llama patch file
|
||||||
|
import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch
|
||||||
|
|
||||||
|
llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access
|
||||||
|
|
||||||
|
from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen3.Qwen3ForCausalLM
|
||||||
|
), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(cce_forward, maybe_model)
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
|
||||||
|
return None
|
||||||
@@ -0,0 +1,194 @@
|
|||||||
|
"""Qwen3 MoE CCE patch. Adapted from transformers v4.51.2"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
|
from types import MethodType
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
|
from cut_cross_entropy.transformers.utils import (
|
||||||
|
PatchOptions,
|
||||||
|
TransformersModelT,
|
||||||
|
apply_lce,
|
||||||
|
)
|
||||||
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
|
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
|
||||||
|
_CONFIG_FOR_DOC,
|
||||||
|
QWEN3_MOE_INPUTS_DOCSTRING,
|
||||||
|
KwargsForCausalLM,
|
||||||
|
MoeCausalLMOutputWithPast,
|
||||||
|
MoeModelOutputWithPast,
|
||||||
|
load_balancing_loss_func,
|
||||||
|
)
|
||||||
|
from transformers.processing_utils import Unpack
|
||||||
|
from transformers.utils import (
|
||||||
|
add_start_docstrings_to_model_forward,
|
||||||
|
replace_return_docstrings,
|
||||||
|
)
|
||||||
|
from transformers.utils.deprecation import deprecate_kwarg
|
||||||
|
from transformers.utils.generic import can_return_tuple
|
||||||
|
|
||||||
|
_PATCH_OPTS: PatchOptions | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@can_return_tuple
|
||||||
|
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
|
||||||
|
@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING)
|
||||||
|
@replace_return_docstrings(
|
||||||
|
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
||||||
|
)
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[list[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
labels: Optional[torch.LongTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
|
logits_to_keep: Union[int, torch.Tensor] = 0,
|
||||||
|
**kwargs: Unpack[KwargsForCausalLM],
|
||||||
|
) -> MoeCausalLMOutputWithPast:
|
||||||
|
r"""
|
||||||
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||||
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||||
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||||
|
|
||||||
|
logits_to_keep (`int` or `torch.Tensor`, *optional*):
|
||||||
|
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
|
||||||
|
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
|
||||||
|
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
|
||||||
|
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
|
||||||
|
This is useful when using packed tensor format (single dimension for batch and sequence length).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM
|
||||||
|
|
||||||
|
>>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||||
|
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B")
|
||||||
|
|
||||||
|
>>> prompt = "Hey, are you conscious? Can you talk to me?"
|
||||||
|
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # Generate
|
||||||
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||||
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||||
|
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
|
||||||
|
```"""
|
||||||
|
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
|
||||||
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
|
outputs: MoeModelOutputWithPast = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=past_key_values,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
|
use_cache=use_cache,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_hidden_states=output_hidden_states,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
cache_position=cache_position,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = outputs.last_hidden_state
|
||||||
|
|
||||||
|
if hidden_states is None:
|
||||||
|
raise ValueError("hidden_states is None")
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
logits = None
|
||||||
|
|
||||||
|
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
||||||
|
slice_indices = (
|
||||||
|
slice(-logits_to_keep, None)
|
||||||
|
if isinstance(logits_to_keep, int)
|
||||||
|
else logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
|
if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
|
||||||
|
assert labels is not None
|
||||||
|
loss = apply_lce(
|
||||||
|
hidden_states[:, slice_indices, :],
|
||||||
|
self.lm_head.weight,
|
||||||
|
labels,
|
||||||
|
_PATCH_OPTS,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logits = self.lm_head(hidden_states[:, slice_indices, :])
|
||||||
|
|
||||||
|
if labels is not None:
|
||||||
|
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)
|
||||||
|
|
||||||
|
aux_loss = None
|
||||||
|
if output_router_logits:
|
||||||
|
aux_loss = load_balancing_loss_func(
|
||||||
|
outputs.router_logits,
|
||||||
|
self.num_experts,
|
||||||
|
self.num_experts_per_tok,
|
||||||
|
attention_mask,
|
||||||
|
)
|
||||||
|
if labels is not None:
|
||||||
|
loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore
|
||||||
|
loss.device # type: ignore
|
||||||
|
) # make sure to reside in the same device
|
||||||
|
|
||||||
|
return MoeCausalLMOutputWithPast(
|
||||||
|
loss=loss,
|
||||||
|
aux_loss=aux_loss, # type: ignore
|
||||||
|
logits=logits,
|
||||||
|
past_key_values=outputs.past_key_values,
|
||||||
|
hidden_states=outputs.hidden_states,
|
||||||
|
attentions=outputs.attentions,
|
||||||
|
router_logits=outputs.router_logits,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def patch_qwen3_moe(
|
||||||
|
maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
|
||||||
|
patch_options: PatchOptions,
|
||||||
|
) -> TransformersModelT | None:
|
||||||
|
global _PATCH_OPTS # pylint: disable=global-statement
|
||||||
|
|
||||||
|
from transformers.models.qwen3_moe import modeling_qwen3_moe
|
||||||
|
|
||||||
|
_PATCH_OPTS = patch_options
|
||||||
|
|
||||||
|
if isinstance(maybe_model, transformers.PreTrainedModel):
|
||||||
|
assert isinstance(
|
||||||
|
maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM
|
||||||
|
), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}."
|
||||||
|
maybe_model.forward = MethodType(forward, maybe_model)
|
||||||
|
|
||||||
|
return maybe_model
|
||||||
|
|
||||||
|
modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward
|
||||||
|
return None
|
||||||
Reference in New Issue
Block a user