From 1178a15ede612df0fdfb6dc65cb274c94a57e713 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Mon, 28 Apr 2025 23:18:46 +0700 Subject: [PATCH] Feat: Add qwen3 and CCE for qwen family (#2518) --- examples/qwen3/qlora-fsdp.yaml | 68 +++++ .../integrations/cut_cross_entropy/README.md | 7 +- .../cut_cross_entropy/monkeypatch/llama.py | 174 ++++++++++++ .../cut_cross_entropy/monkeypatch/patch.py | 26 +- .../cut_cross_entropy/monkeypatch/qwen2.py | 37 +++ .../monkeypatch/qwen2_5_vl.py | 246 +++++++++++++++++ .../monkeypatch/qwen2_moe.py | 188 +++++++++++++ .../cut_cross_entropy/monkeypatch/qwen2_vl.py | 249 ++++++++++++++++++ .../cut_cross_entropy/monkeypatch/qwen3.py | 35 +++ .../monkeypatch/qwen3_moe.py | 194 ++++++++++++++ 10 files changed, 1221 insertions(+), 3 deletions(-) create mode 100644 examples/qwen3/qlora-fsdp.yaml create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py create mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py diff --git a/examples/qwen3/qlora-fsdp.yaml b/examples/qwen3/qlora-fsdp.yaml new file mode 100644 index 000000000..dc3377b4f --- /dev/null +++ b/examples/qwen3/qlora-fsdp.yaml @@ -0,0 +1,68 @@ +base_model: Qwen/Qwen3-8B +# Automatically upload checkpoint and final model to HF +# hub_model_id: username/custom_model_name + +load_in_8bit: false +load_in_4bit: true +strict: false + +datasets: + - path: tatsu-lab/alpaca + type: alpaca +dataset_prepared_path: +val_set_size: 0.05 +output_dir: ./outputs/out + +sequence_len: 2048 +sample_packing: true +eval_sample_packing: true +pad_to_sequence_len: true + +adapter: qlora +lora_model_dir: +lora_r: 32 +lora_alpha: 64 +lora_dropout: 0.05 +lora_target_linear: true + +wandb_project: +wandb_entity: +wandb_watch: +wandb_name: +wandb_log_model: + +gradient_accumulation_steps: 4 +micro_batch_size: 1 +num_epochs: 1 +optimizer: adamw_torch_fused +lr_scheduler: cosine +learning_rate: 0.0002 + +bf16: auto +tf32: true + +gradient_checkpointing: true +gradient_checkpointing_kwargs: + use_reentrant: false +resume_from_checkpoint: +logging_steps: 1 +flash_attention: true + +warmup_steps: 10 +evals_per_epoch: 4 +saves_per_epoch: 1 +weight_decay: 0.0 +fsdp: + - full_shard + - auto_wrap +fsdp_config: + fsdp_limit_all_gathers: true + fsdp_sync_module_states: true + fsdp_offload_params: true + fsdp_use_orig_params: false + fsdp_cpu_ram_efficient_loading: true + fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP + fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer + fsdp_state_dict_type: FULL_STATE_DICT + fsdp_sharding_strategy: FULL_SHARD +special_tokens: diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md index 462bcbedc..627ebd935 100644 --- a/src/axolotl/integrations/cut_cross_entropy/README.md +++ b/src/axolotl/integrations/cut_cross_entropy/README.md @@ -32,8 +32,8 @@ plugins: ## Supported Models - llama -- llama4_text - llama4 +- llama4_text - mllama - phi3 - gemma @@ -43,6 +43,11 @@ plugins: - mistral - mistral3 - qwen2 +- qwen2_moe +- qwen2_vl +- qwen2_5_vl +- qwen3 +- qwen3_moe - cohere - 
cohere2 - glm diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py new file mode 100644 index 000000000..42ab996b9 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py @@ -0,0 +1,174 @@ +"""Llama CCE patch. Adapted from transformers v4.51.2""" + +# pylint: disable=duplicate-code + + +from types import MethodType +from typing import Optional, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.models.llama.modeling_llama import ( + _CONFIG_FOR_DOC, + LLAMA_INPUTS_DOCSTRING, + KwargsForCausalLM, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import can_return_tuple + +_PATCH_OPTS: PatchOptions | None = None + + +@can_return_tuple +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], +) -> CausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+ >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + if hidden_states is None: + raise ValueError("hidden_states is None") + + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **kwargs, + ) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def patch_llama( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + """Patch Llama for CCE.""" + global _PATCH_OPTS # pylint: disable=global-statement + from transformers.models.llama import modeling_llama + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_llama.LlamaForCausalLM + ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." 
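+        # MethodType binds cce_forward to this one instance only; the class-level
+        # assignment in the fallback path below instead patches LlamaForCausalLM
+        # itself, so models constructed after patching also pick up the CCE forward.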
+ maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_llama.LlamaForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py index 9e18c6b0b..8176a1f0c 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py @@ -5,9 +5,7 @@ import transformers from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT -from cut_cross_entropy.transformers.llama import patch_llama from cut_cross_entropy.transformers.phi3 import patch_phi3 -from cut_cross_entropy.transformers.qwen2 import patch_qwen2 from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import ( @@ -24,6 +22,9 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import ( patch_glm, patch_glm4, ) +from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( + patch_llama, +) from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( patch_llama4, patch_llama4_text, @@ -33,6 +34,22 @@ from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( patch_mistral3, ) from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import ( + patch_qwen2, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import ( + patch_qwen2_5_vl, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import ( + patch_qwen2_moe, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import ( + patch_qwen2_vl, +) +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3 +from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import ( + patch_qwen3_moe, +) CUT_CROSS_ENTROPY_MODEL_MAPPING = { "llama": patch_llama, @@ -47,6 +64,11 @@ CUT_CROSS_ENTROPY_MODEL_MAPPING = { "mistral": patch_mistral, "mistral3": patch_mistral3, "qwen2": patch_qwen2, + "qwen2_moe": patch_qwen2_moe, + "qwen2_vl": patch_qwen2_vl, + "qwen2_5_vl": patch_qwen2_5_vl, + "qwen3": patch_qwen3, + "qwen3_moe": patch_qwen3_moe, "cohere": patch_cohere, "cohere2": patch_cohere2, "glm": patch_glm, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py new file mode 100644 index 000000000..3f6d2b3e9 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py @@ -0,0 +1,37 @@ +"""Qwen2 CCE patch. 
The model inherits Llama's modeling code and uses the same forward method.""" + +# pylint: disable=duplicate-code + +from types import MethodType + +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, +) + + +def patch_qwen2( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + from transformers.models.qwen2 import modeling_qwen2 + + # Set the _PATCH_OPTS in the llama patch file + import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch + + llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access + + from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( + cce_forward, + ) + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_qwen2.Qwen2ForCausalLM + ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py new file mode 100644 index 000000000..16206006f --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py @@ -0,0 +1,246 @@ +"""Qwen2.5 VL CCE patch. Adapted from transformers v4.51.2""" + +# pylint: disable=duplicate-code + + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from torch.nn import CrossEntropyLoss +from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( + Qwen2_5_VLCausalLMOutputWithPast, +) + +_PATCH_OPTS: PatchOptions | None = None + + +def cce_forward_multimodal( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + second_per_grid_ts: Optional[torch.Tensor] = None, +) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.dtype) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.dtype) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + second_per_grid_ts, + attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore + position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore + position_ids = position_ids.add(delta) # type: ignore + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = None + loss = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states, + self.lm_head.weight, + labels, + _PATCH_OPTS, + ) + else: + logits = self.lm_head(hidden_states) + + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2_5_VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +def patch_qwen2_5_vl( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + + from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration + ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}." 
+ maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + return maybe_model + + modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = ( + cce_forward_multimodal + ) + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py new file mode 100644 index 000000000..0811bf55a --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py @@ -0,0 +1,188 @@ +"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.models.qwen2_moe.modeling_qwen2_moe import ( + _CONFIG_FOR_DOC, + QWEN2MOE_INPUTS_DOCSTRING, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + load_balancing_loss_func, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import can_return_tuple + +_PATCH_OPTS: PatchOptions | None = None + + +@can_return_tuple +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **loss_kwargs, +) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM + + >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" 
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_router_logits = (
+        output_router_logits
+        if output_router_logits is not None
+        else self.config.output_router_logits
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs: MoeModelOutputWithPast = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        output_router_logits=output_router_logits,
+        cache_position=cache_position,
+    )
+
+    hidden_states = outputs.last_hidden_state
+    loss = None
+    logits = None
+
+    if hidden_states is None:
+        raise ValueError("hidden_states is None")
+
+    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+    slice_indices = (
+        slice(-logits_to_keep, None)
+        if isinstance(logits_to_keep, int)
+        else logits_to_keep
+    )
+
+    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
+        assert labels is not None
+        loss = apply_lce(
+            hidden_states[:, slice_indices, :],
+            self.lm_head.weight,
+            labels,
+            _PATCH_OPTS,
+            **loss_kwargs,
+        )
+    else:
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+        if labels is not None:
+            loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
+
+    aux_loss = None
+    if output_router_logits:
+        aux_loss = load_balancing_loss_func(
+            outputs.router_logits,
+            self.num_experts,
+            self.num_experts_per_tok,
+            attention_mask,
+        )
+        if labels is not None:
+            loss += self.router_aux_loss_coef * aux_loss.to(  # type: ignore
+                loss.device  # type: ignore
+            )  # make sure to reside in the same device
+
+    return MoeCausalLMOutputWithPast(
+        loss=loss,
+        aux_loss=aux_loss,  # type: ignore
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+        router_logits=outputs.router_logits,
+    )
+
+
+def patch_qwen2_moe(
+    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
+    patch_options: PatchOptions,
+) -> TransformersModelT | None:
+    global _PATCH_OPTS  # pylint: disable=global-statement
+
+    from transformers.models.qwen2_moe import modeling_qwen2_moe
+
+    _PATCH_OPTS = patch_options
+
+    if isinstance(maybe_model, transformers.PreTrainedModel):
+        assert isinstance(
+            maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM
+        ), f"Expected a Qwen2MoeForCausalLM model. Got {type(maybe_model)}."
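+        # Same instance-level vs. class-level patching pattern as monkeypatch/llama.py.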
+ maybe_model.forward = MethodType(forward, maybe_model) + + return maybe_model + + modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py new file mode 100644 index 000000000..250c3ab6b --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py @@ -0,0 +1,249 @@ +"""Qwen2 VL CCE patch. Adapted from transformers v4.51.2""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Tuple, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from torch.nn import CrossEntropyLoss +from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + _CONFIG_FOR_DOC, + QWEN2_VL_INPUTS_DOCSTRING, + Qwen2VLCausalLMOutputWithPast, +) +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) + +_PATCH_OPTS: PatchOptions | None = None + + +@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def cce_forward_multimodal( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Returns: + + Example: + + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration + + >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") + + >>> messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": "What is shown in this image?"}, + ], + }, + ] + >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + + >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." + ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = ( + cache_position[0] + self.rope_deltas + if cache_position is not None + else 0 + ) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore + position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore + delta = delta.to(position_ids.device) # type: ignore + position_ids = position_ids.add(delta) # type: ignore + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore + + outputs = self.model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + logits = None + loss = None + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states, + self.lm_head.weight, + labels, + _PATCH_OPTS, + ) + else: + logits = self.lm_head(hidden_states) + + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return Qwen2VLCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + rope_deltas=self.rope_deltas, + ) + + +def patch_qwen2_vl( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + + from transformers.models.qwen2_vl import modeling_qwen2_vl + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration + ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}." 
+ maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) + + return maybe_model + + modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py new file mode 100644 index 000000000..799a4f357 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py @@ -0,0 +1,35 @@ +"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" + +# pylint: disable=duplicate-code + +from types import MethodType + +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, +) + + +def patch_qwen3( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + from transformers.models.qwen3 import modeling_qwen3 + + # Set the _PATCH_OPTS in the llama patch file + import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch + + llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access + + from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_qwen3.Qwen3ForCausalLM + ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(cce_forward, maybe_model) + return maybe_model + + modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward + return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py new file mode 100644 index 000000000..c5cd76f94 --- /dev/null +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py @@ -0,0 +1,194 @@ +"""Qwen3 MoE CCE patch. 
Adapted from transformers v4.51.2""" + +# pylint: disable=duplicate-code + +from types import MethodType +from typing import Optional, Union + +import torch +import transformers +from cut_cross_entropy.transformers.utils import ( + PatchOptions, + TransformersModelT, + apply_lce, +) +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.qwen3_moe.modeling_qwen3_moe import ( + _CONFIG_FOR_DOC, + QWEN3_MOE_INPUTS_DOCSTRING, + KwargsForCausalLM, + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + load_balancing_loss_func, +) +from transformers.processing_utils import Unpack +from transformers.utils import ( + add_start_docstrings_to_model_forward, + replace_return_docstrings, +) +from transformers.utils.deprecation import deprecate_kwarg +from transformers.utils.generic import can_return_tuple + +_PATCH_OPTS: PatchOptions | None = None + + +@can_return_tuple +@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") +@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) +@replace_return_docstrings( + output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC +) +def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[KwargsForCausalLM], +) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM + + >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_router_logits = ( + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + if hidden_states is None: + raise ValueError("hidden_states is None") + + loss = None + logits = None + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) + if isinstance(logits_to_keep, int) + else logits_to_keep + ) + + if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): + assert labels is not None + loss = apply_lce( + hidden_states[:, slice_indices, :], + self.lm_head.weight, + labels, + _PATCH_OPTS, + **kwargs, + ) + else: + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore + loss.device # type: ignore + ) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, # type: ignore + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + +def patch_qwen3_moe( + maybe_model: TransformersModelT | str | transformers.PretrainedConfig, + patch_options: PatchOptions, +) -> TransformersModelT | None: + global _PATCH_OPTS # pylint: disable=global-statement + + from transformers.models.qwen3_moe import modeling_qwen3_moe + + _PATCH_OPTS = patch_options + + if isinstance(maybe_model, transformers.PreTrainedModel): + assert isinstance( + maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM + ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." + maybe_model.forward = MethodType(forward, maybe_model) + + return maybe_model + + modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward + return None