From bb1109b81d4dd058323fe9c035e3d1dd00e66de4 Mon Sep 17 00:00:00 2001
From: NanoCode012
Date: Wed, 25 Jun 2025 20:49:22 +0700
Subject: [PATCH] feat: update CCE to use axolotl's fork (#2813) [skip ci]

* feat: update CCE to use axolotl's fork

* chore: improve error message

* feat: add eot token for gemma3 configs

* fix: only warn on more than 1 image

* fix: re-add gemma3 patch

* Revert "fix: re-add gemma3 patch"

This reverts commit f04db5e873bfab705a39b6b860fe8249796977d3.

* feat: add qwen25 vl example

* feat: point to upstream fork cce package

* feat: update cce commit
---
 examples/gemma3/gemma-3-1b-qlora.yml | 2 +
 examples/gemma3/gemma-3-4b-qlora.yml | 2 +
 examples/gemma3/gemma-3-4b-vision-qlora.yml | 2 +
 examples/qwen2_5-vl/lora-7b.yaml | 55 +++
 scripts/cutcrossentropy_install.py | 2 +-
 .../integrations/cut_cross_entropy/README.md | 32 +-
 .../cut_cross_entropy/__init__.py | 24 +-
 .../cut_cross_entropy/monkeypatch/__init__.py | 0
 .../cut_cross_entropy/monkeypatch/cohere.py | 191 --------
 .../cut_cross_entropy/monkeypatch/gemma.py | 165 -------
 .../cut_cross_entropy/monkeypatch/gemma3.py | 447 ------------------
 .../cut_cross_entropy/monkeypatch/glm4.py | 57 ---
 .../cut_cross_entropy/monkeypatch/llama.py | 164 -------
 .../cut_cross_entropy/monkeypatch/llama4.py | 401 ----------------
 .../cut_cross_entropy/monkeypatch/mistral3.py | 384 ---------------
 .../cut_cross_entropy/monkeypatch/mllama.py | 366 --------------
 .../cut_cross_entropy/monkeypatch/patch.py | 126 -----
 .../cut_cross_entropy/monkeypatch/qwen2.py | 37 --
 .../monkeypatch/qwen2_5_vl.py | 246 ----------
 .../monkeypatch/qwen2_moe.py | 178 -------
 .../cut_cross_entropy/monkeypatch/qwen2_vl.py | 239 ----------
 .../cut_cross_entropy/monkeypatch/qwen3.py | 35 --
 .../monkeypatch/qwen3_moe.py | 183 -------
 .../cut_cross_entropy/monkeypatch/utils.py | 40 --
 src/axolotl/processing_strategies.py | 2 +-
 25 files changed, 94 insertions(+), 3286 deletions(-)
 create mode 100644 examples/qwen2_5-vl/lora-7b.yaml
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py
 delete mode 100644 src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py

diff --git a/examples/gemma3/gemma-3-1b-qlora.yml b/examples/gemma3/gemma-3-1b-qlora.yml
index 44310558c..217c887aa 100644
--- a/examples/gemma3/gemma-3-1b-qlora.yml
+++ b/examples/gemma3/gemma-3-1b-qlora.yml
@@ -13,6 +13,8 @@ load_in_4bit: true
 # huggingface repo
 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
diff --git a/examples/gemma3/gemma-3-4b-qlora.yml b/examples/gemma3/gemma-3-4b-qlora.yml
index 0d89d9ffb..d78559ae3 100644
--- a/examples/gemma3/gemma-3-4b-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-qlora.yml
@@ -6,6 +6,8 @@ load_in_4bit: true
 ddp_find_unused_parameters: true
 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: cgato/SlimOrcaDedupCleaned
     type: chat_template
diff --git a/examples/gemma3/gemma-3-4b-vision-qlora.yml b/examples/gemma3/gemma-3-4b-vision-qlora.yml
index 339df92e5..183eb88e8 100644
--- a/examples/gemma3/gemma-3-4b-vision-qlora.yml
+++ b/examples/gemma3/gemma-3-4b-vision-qlora.yml
@@ -12,6 +12,8 @@ sample_packing: false
 ddp_find_unused_parameters: true
 chat_template: gemma3
+eot_tokens:
+  - <end_of_turn>
 datasets:
   - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
diff --git a/examples/qwen2_5-vl/lora-7b.yaml b/examples/qwen2_5-vl/lora-7b.yaml
new file mode 100644
index 000000000..25d02805f
--- /dev/null
+++ b/examples/qwen2_5-vl/lora-7b.yaml
@@ -0,0 +1,55 @@
+base_model: Qwen/Qwen2.5-VL-7B-Instruct
+processor_type: AutoProcessor
+
+# these 3 lines are needed for now to handle vision chat templates w images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
+
+chat_template: qwen2_vl
+datasets:
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
+    type: chat_template
+    split: train[:1%]
+    field_messages: messages
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.0
+output_dir: ./outputs/out
+
+adapter: lora
+lora_model_dir:
+
+sequence_len: 8192
+pad_to_sequence_len: false
+
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_modules: 'model.language_model.layers.[\d]+.(mlp|cross_attn|self_attn).(up|down|gate|q|k|v|o)_proj'
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_name:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 1
+num_epochs: 1
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+bf16: true
+fp16:
+tf32: true
+
+gradient_checkpointing: true
+logging_steps: 1
+flash_attention: true
+eager_attention:
+
+warmup_ratio: 0.1
+evals_per_epoch: 1
+saves_per_epoch: 1
+weight_decay: 0.0
diff --git a/scripts/cutcrossentropy_install.py b/scripts/cutcrossentropy_install.py
index 4a92746c1..bb9224bb0 100644
--- a/scripts/cutcrossentropy_install.py
+++ b/scripts/cutcrossentropy_install.py
@@ -29,5 +29,5 @@ UV_PREFIX = "uv " if USE_UV else ""
 print(
     UNINSTALL_PREFIX
-    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@a1174ca"'
+    + f'{UV_PREFIX}pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"'
 )
diff --git a/src/axolotl/integrations/cut_cross_entropy/README.md b/src/axolotl/integrations/cut_cross_entropy/README.md
index bddf3ced2..b5e3ecda8 100644
--- a/src/axolotl/integrations/cut_cross_entropy/README.md
+++ b/src/axolotl/integrations/cut_cross_entropy/README.md
@@ -19,19 +19,11 @@ python scripts/cutcrossentropy_install.py | sh
 
 - If you are installing from pip
 
 ```bash
-pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"
+pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@78b2a45713a54c9bedf8b33f5e31cf07a1a57154"
 ```
 
 ## Usage
 
-**NOTE**: If you are training a VLM model, please use older version of Axolotl as upstream has applied a major VLM refactor, and our patches have not been updated yet.
-
-```bash
-git checkout 787880215b3ab32ccaf81c1b2e9588c6f3e6e764
-
-pip3 install --no-build-isolation -e .
-```
-
 ```yaml
 plugins:
   - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
@@ -39,27 +31,29 @@ plugins:
 
 ## Supported Models
 
-- llama
-- llama4
-- llama4_text
-- mllama
-- phi3
+- cohere
+- cohere2
 - gemma
 - gemma2
 - gemma3
 - gemma3_text
+- glm
+- glm4
+- llama
+- llama4
+- llama4_text
 - mistral
 - mistral3
+- mllama
+- phi
+- phi3
+- phi4_multimodal
 - qwen2
-- qwen2_moe
 - qwen2_vl
+- qwen2_moe
 - qwen2_5_vl
 - qwen3
 - qwen3_moe
-- cohere
-- cohere2
-- glm
-- glm4
 
 ## Citation
diff --git a/src/axolotl/integrations/cut_cross_entropy/__init__.py b/src/axolotl/integrations/cut_cross_entropy/__init__.py
index c29bb55d4..37f4dba68 100644
--- a/src/axolotl/integrations/cut_cross_entropy/__init__.py
+++ b/src/axolotl/integrations/cut_cross_entropy/__init__.py
@@ -31,8 +31,8 @@ from .args import CutCrossEntropyArgs  # pylint: disable=unused-import. # noqa:
 LOG = get_logger(__name__)
 
 _CCE_INSTALL_MESSAGE = (
-    "Please install cut_cross_entropy with transformers support using "
-    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@bad6f7b49c75fdec69471abb71b4cddd0f0c6438"`'
+    "Please install Axolotl's fork of cut_cross_entropy with transformers support using "
+    '`pip install "cut-cross-entropy[transformers] @ git+https://github.com/axolotl-ai-cloud/ml-cross-entropy.git@7f6afce"`'
 )
@@ -64,16 +64,28 @@ class CutCrossEntropyPlugin(BasePlugin):
             "cut_cross_entropy.transformers"
         )
         if cce_spec_transformers is None:
-            raise ImportError(_CCE_INSTALL_MESSAGE)
+            raise ImportError(
+                "Transformers support is not installed. " + _CCE_INSTALL_MESSAGE
+            )
+
+        # Check if Axolotl's cce fork is installed
+        try:
+            from cut_cross_entropy.transformers.patch import AXOLOTL_CCE_FORK
+
+            if not AXOLOTL_CCE_FORK:
+                raise ImportError
+        except ImportError as e:
+            raise ImportError(
+                "Axolotl's fork of cut_cross_entropy is not installed. "
+                + _CCE_INSTALL_MESSAGE
+            ) from e
 
     def pre_model_load(self, cfg):
         """Apply cut cross entropy before model loading if enabled."""
         if cfg.cut_cross_entropy:
             self._check_requirements()
 
-            from .monkeypatch.patch import (
-                cce_patch,
-            )
+            from cut_cross_entropy.transformers.patch import cce_patch
 
             LOG.info(
                 f"Applying Cut Cross Entropy to model type: {cfg.model_config_type}"
diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
deleted file mode 100644
index ea9e10724..000000000
--- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""Cohere and Cohere2 CCE patch."""
-
-# This patch is based off transformers 4.50.0.
-# It patches the forward function for CohereForCausalLM and Cohere2ForCausalLM.
-# It scales the hidden states by the logit scale in advance instead of the logits as the -# operation is done internally and should be mathematically equivalent. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.cohere.modeling_cohere import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >> from transformers import AutoTokenizer, CohereForCausalLM - - >> model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01") - >> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01") - - >> prompt = "Hey, are you conscious? Can you talk to me?" - >> inputs = tokenizer(prompt, return_tensors="pt") - - >> # Generate - >> generate_ids = model.generate(inputs.input_ids, max_length=30) - >> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # scale hidden_states by logit_scale in-place of logits - loss = apply_lce( - hidden_states[:, slice_indices, :] * self.logit_scale, - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - logits = logits * self.logit_scale # main diff from Llama - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_cohere( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere import modeling_cohere - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere.CohereForCausalLM - ), f"Expected a CohereForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere.CohereForCausalLM.forward = cce_forward - return None - - -def patch_cohere2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.cohere2 import modeling_cohere2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_cohere2.Cohere2ForCausalLM - ), f"Expected a Cohere2ForCausalLM model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_cohere2.Cohere2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py deleted file mode 100644 index ae3d8c6ef..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ /dev/null @@ -1,165 +0,0 @@ -"""Gemma CCE patch""" - -# This patch is based off transformers 4.50.0. - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma.modeling_gemma import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, GemmaForCausalLM - - >>> model = GemmaForCausalLM.from_pretrained("google/gemma-7b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_gemma( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma import modeling_gemma - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma.GemmaForCausalLM - ), f"Expected a GemmaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma.GemmaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py deleted file mode 100644 index 644e5cce7..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ /dev/null @@ -1,447 +0,0 @@ -"""Gemma2 and Gemma3 (text and multimodal) CCE patch.""" - -# Implementation originally adapted from https://github.com/apple/ml-cross-entropy/pull/29 -# and updated for transformers 4.50.0. -# This is a modified version of the patch that allows for deferred logits calculation for gemma3 and works -# with both gemma3 (text and multimodal) models. 
- -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) -from torch import nn -from transformers.cache_utils import Cache, HybridCache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.gemma3.modeling_gemma3 import ( - Gemma3CausalLMOutputWithPast, - logger, -) -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -from axolotl.integrations.cut_cross_entropy.monkeypatch.utils import apply_lce - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[HybridCache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Gemma3ForCausalLM - - >>> model = Gemma3ForCausalLM.from_pretrained("google/gemma-2-9b") - >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - >>> prompt = "What is your favorite condiment?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is your favorite condiment?" 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if self.config.final_logit_softcapping is not None: - logits = logits / self.config.final_logit_softcapping - logits = torch.tanh(logits) - logits = logits * self.config.final_logit_softcapping - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None, - token_type_ids: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **lm_kwargs, -) -> Union[Tuple, Gemma3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.text_config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.text_config.vocab_size]`. 
- - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Gemma3ForConditionalGeneration - - >>> model = Gemma3ForConditionalGeneration.from_pretrained("google/Gemma3-test-224px-hf") - >>> processor = AutoProcessor.from_pretrained("google/Gemma3-test-224px-hf") - - >>> prompt = "answer en Where is the cow standing?" - >>> url = "https://huggingface.co/gv-hf/Gemma3-test-224px-hf/resolve/main/cow_beach_1.png" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "answer en Where is the cow standing?\nbeach" - ```""" - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - is_training = token_type_ids is not None and labels is not None - - # Replace image id woth PAD if the image token if OOV, to avoid index-errors - if input_ids is not None and self.config.image_token_index >= self.vocab_size: - special_image_mask = input_ids == self.config.image_token_index - llm_input_ids = input_ids.clone() - llm_input_ids[special_image_mask] = 0 - else: - llm_input_ids = input_ids # type: ignore - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(llm_input_ids) - - if cache_position is None: - past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 # type: ignore - ) - cache_position = torch.arange( # type: ignore - past_seen_tokens, - past_seen_tokens + inputs_embeds.shape[1], - device=inputs_embeds.device, - ) - - # Merge text and images - if pixel_values is not None: - image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor( - self.config.image_token_index, - dtype=torch.long, - device=inputs_embeds.device, - ) - ) - else: - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze( - -1 - ) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input 
text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - # mask out pad-token-ids in labels for BC - if labels is not None and self.pad_token_id in labels: - logger.warning_once( - "`labels` contains `pad_token_id` which will be masked with `config.ignore_index`. " - "You have to mask out `pad_token_id` when preparing `labels`, this behavior will be removed in v.4.46.", - ) - labels = torch.where( # type: ignore - input_ids == self.pad_token_id, self.config.ignore_index, labels - ) - - causal_mask = self._update_causal_mask( # pylint: disable=protected-access - attention_mask, - token_type_ids, - past_key_values, - cache_position, - inputs_embeds, - is_training, - ) - outputs = self.language_model( - attention_mask=causal_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - softcap=getattr(self.config, "final_logit_softcapping", None), - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - shift_logits = logits[..., :-1, :] - shift_labels = labels[..., 1:] - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -shift_logits.shape[1] :].to( - logits.device - ) - shift_logits = shift_logits[ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = shift_labels[ - shift_attention_mask.to(shift_labels.device) != 0 - ].contiguous() - else: - shift_logits = shift_logits.contiguous() - shift_labels = shift_labels.contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - - flat_logits = shift_logits.view(-1, self.config.text_config.vocab_size) - flat_labels = shift_labels.view(-1).to(shift_logits.device) - loss = loss_fct(flat_logits, flat_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Gemma3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_gemma2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma2 import modeling_gemma2 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma2.Gemma2ForCausalLM - ), f"Expected a Gemma2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma2.Gemma2ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForCausalLM - ), f"Expected a Gemma3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None - - -def patch_gemma3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.gemma3 import modeling_gemma3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_gemma3.Gemma3ForConditionalGeneration - ), f"Expected a Gemma3ForConditionalGeneration model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_gemma3.Gemma3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_gemma3.Gemma3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py deleted file mode 100644 index 3df909f88..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/glm4.py +++ /dev/null @@ -1,57 +0,0 @@ -"""GLM 4 patch. GLM family inherits from Llama.""" - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_glm( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm import modeling_glm - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm.GlmForCausalLM - ), f"Expected a GlmForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm.GlmForCausalLM.forward = cce_forward - return None - - -def patch_glm4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - # Set the _PATCH_OPTS in the llama patch file - import cut_cross_entropy.transformers.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from cut_cross_entropy.transformers.llama import cce_forward - from transformers.models.glm4 import modeling_glm4 - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_glm4.Glm4ForCausalLM - ), f"Expected a Glm4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_glm4.Glm4ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py deleted file mode 100644 index bed411ace..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ /dev/null @@ -1,164 +0,0 @@ -"""Llama CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, -) -from transformers.models.llama.modeling_llama import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> CausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM - - >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_llama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - """Patch Llama for CCE.""" - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama import modeling_llama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama.LlamaForCausalLM - ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_llama.LlamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py deleted file mode 100644 index 3143e9c8d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ /dev/null @@ -1,401 +0,0 @@ -"""Llama4 CCE patch. 
Adapted from transformers 4.51.0.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.llama4.modeling_llama4 import ( - Llama4CausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*, defaults to `False`): - If `True`, defer logits calculation to the ConditionalGeneration forward. This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Llama4ForCausalLM - - >>> model = Llama4ForCausalLM.from_pretrained("meta-llama4/Llama4-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama4/Llama4-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, # type: ignore - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - vision_feature_select_strategy: Optional[str] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - **lm_kwargs, -) -> Union[Tuple, Llama4CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, LlavaForConditionalGeneration - - >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") - >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - - >>> prompt = "USER: \nWhat's the content of the image? ASSISTANT:" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed" - ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_config.vision_feature_layer - ) - vision_feature_select_strategy = ( - vision_feature_select_strategy - if vision_feature_select_strategy is not None - else self.config.vision_config.vision_feature_select_strategy - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) # type: ignore - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - vision_feature_select_strategy=vision_feature_select_strategy, - image_sizes=image_sizes, - ) - original_inputs_embeds_shape = inputs_embeds.shape # type: ignore - - vision_flat = image_features.view(-1, image_features.size(-1)) - projected_vision_flat = self.multi_modal_projector(vision_flat) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - final_mask = special_image_mask.to(inputs_embeds.device) # type: ignore - inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-1)) # type: ignore - - final_mask_1d = final_mask[..., 0].reshape(-1) - num_tokens_to_fill = final_mask_1d.sum() - - if num_tokens_to_fill != projected_vision_flat.size(0): - raise ValueError( - f"Mismatch: final_mask wants {num_tokens_to_fill} embeddings, " - f"but multi_modal_projector returned {projected_vision_flat.size(0)}" - ) - - expanded_mask = final_mask_1d.unsqueeze(-1).expand(-1, inputs_embeds.size(-1)) - inputs_embeds = 
inputs_embeds.masked_scatter( - expanded_mask, projected_vision_flat - ) # type: ignore - inputs_embeds = inputs_embeds.view(original_inputs_embeds_shape) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - # TODO: check if need to handle attention_mask - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. - # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Llama4CausalLMOutputWithPast( - loss=loss, - logits=logits, # type: ignore # TODO: check if need to create dummy logits - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_llama4_text( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForCausalLM - ), f"Expected a Llama4ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - - return maybe_model - - setattr( - modeling_llama4.Llama4ForCausalLM, - "forward", - cce_forward, - ) - return None - - -def patch_llama4( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.llama4 import modeling_llama4 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_llama4.Llama4ForConditionalGeneration - ), f"Expected a Llama4ForConditionalGeneration model. Got {type(maybe_model)}." 
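The forwards above either hand the hidden states and the `lm_head` weight straight to `apply_lce`, or fall back to materializing logits through `lm_head`. The sketch below is only a reference illustration of the memory trade-off that motivates that choice; it is not the CCE kernel, and the shapes, sizes, and function names are invented for the example.

```python
# Conceptual sketch: full [B, T, V] logits vs. a chunked loss that never holds
# the whole logits tensor at once. Not the CCE/LCE implementation itself.
import torch
import torch.nn.functional as F


def full_logits_loss(hidden, lm_head_weight, labels):
    # hidden: [B, T, H], lm_head_weight: [V, H], labels: [B, T]
    logits = hidden @ lm_head_weight.T  # materializes a [B, T, V] tensor
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100
    )


def chunked_loss(hidden, lm_head_weight, labels, chunk=512):
    # Same value, but only a [B, chunk, V] slice of logits exists at any time.
    total, count = hidden.new_zeros(()), 0
    for start in range(0, hidden.size(1), chunk):
        h = hidden[:, start : start + chunk]
        y = labels[:, start : start + chunk]
        logits = h @ lm_head_weight.T
        total = total + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            y.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )
        count += int((y != -100).sum())
    return total / max(count, 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    hidden = torch.randn(2, 2048, 64)
    weight = torch.randn(8192, 64)
    labels = torch.randint(0, 8192, (2, 2048))
    print(full_logits_loss(hidden, weight, labels).item())
    print(chunked_loss(hidden, weight, labels).item())
```

Both paths return the same mean loss; the point is that the second never allocates the full sequence-by-vocabulary logits tensor, which is the allocation CCE avoids for training.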
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - setattr( - modeling_llama4.Llama4ForConditionalGeneration, - "forward", - cce_forward_multimodal, - ) - - # patch the causal language model - setattr(modeling_llama4.Llama4ForCausalLM, "forward", cce_forward) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py deleted file mode 100644 index aa252701e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ /dev/null @@ -1,384 +0,0 @@ -"""Mistral and Mistral3 CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch import nn -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mistral3.modeling_mistral3 import ( - Mistral3CausalLMOutputWithPast, -) -from transformers.models.mistral.modeling_mistral import ( - KwargsForCausalLM, -) -from transformers.processing_utils import Unpack -from transformers.utils import ( - is_torchdynamo_compiling, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] | None = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **kwargs: Unpack[KwargsForCausalLM], -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. 
This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MistralForCausalLM - - >>> model = MistralForCausalLM.from_pretrained("meta-mistral/Mistral-2-7b-hf") - >>> tokenizer = AutoTokenizer.from_pretrained("meta-mistral/Mistral-2-7b-hf") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - if labels is not None: - loss = self.loss_function( - logits=logits, - labels=labels, - vocab_size=self.config.vocab_size, - **kwargs, - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def cce_forward_multimodal( - self, - input_ids: torch.LongTensor | None = None, - pixel_values: torch.FloatTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - vision_feature_layer: Optional[Union[int, list[int]]] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - image_sizes: torch.Tensor | None = None, - 
**lm_kwargs, -) -> Union[Tuple, Mistral3CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Mistral3ForConditionalGeneration - - >>> model = Mistral3ForConditionalGeneration.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - >>> processor = AutoProcessor.from_pretrained("mistralai/Mistral-Small-3.1-24B-Instruct-2503") - - >>> prompt = "[INST][IMG]What is the image?[/INST]" - >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(images=image, text=prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_new_tokens=15) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "What is the image?The image depicts two cats lying on a pink blanket." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - vision_feature_layer = ( - vision_feature_layer - if vision_feature_layer is not None - else self.config.vision_feature_layer - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if pixel_values is not None: - image_features = self.get_image_features( - pixel_values=pixel_values, - vision_feature_layer=vision_feature_layer, - image_sizes=image_sizes, - ) - - special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(inputs_embeds).to( - inputs_embeds.device - ) - if ( - not is_torchdynamo_compiling() - and inputs_embeds[special_image_mask].numel() != image_features.numel() - ): - n_image_tokens = (input_ids == self.config.image_token_index).sum() - n_image_features = image_features.shape[0] * image_features.shape[1] - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) # type: ignore - - outputs = self.language_model( - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **lm_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **lm_kwargs, - ) - else: - logits = hidden_states - if labels is not None: - # Shift so that tokens < n predict n - if attention_mask is not None: - # we use the input attention mask to shift the logits and labels, because it is 2D. 
- # we also crop attn mask in case it is longer, which happens in PrefixTuning with peft - shift_attention_mask = attention_mask[:, -(logits.shape[1] - 1) :].to( - logits.device - ) - shift_logits = logits[..., :-1, :][ - shift_attention_mask.to(logits.device) != 0 - ].contiguous() - shift_labels = labels[..., 1:][ - shift_attention_mask.to(labels.device) != 0 - ].contiguous() - else: - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), - shift_labels.view(-1).to(shift_logits.device), - ) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Mistral3CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=image_features if pixel_values is not None else None, - ) - - -def patch_mistral( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral.MistralForCausalLM - ), f"Expected a MistralForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None - - -def patch_mistral3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mistral import modeling_mistral - from transformers.models.mistral3 import modeling_mistral3 - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mistral3.Mistral3ForConditionalGeneration - ), f"Expected a Mistral3ForConditionalGeneration model. Got {type(maybe_model)}." 
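The multimodal wrappers in this file follow the same deferral pattern as the Llama4 one above: the text model is called with `defer_logits_calculation=True` so it returns hidden states instead of logits, and the wrapper then computes the loss against the text model's `lm_head` weight. A minimal stand-in for that control flow is sketched below; `TextLM`, `MultimodalLM`, and `fused_linear_ce` are hypothetical placeholders, not the patched `transformers` classes or the real `apply_lce`.

```python
# Toy illustration of the "defer logits to the wrapper" pattern used above.
import torch
import torch.nn as nn
import torch.nn.functional as F


def fused_linear_ce(hidden, weight, labels):
    # Stand-in for apply_lce: loss computed from hidden states + lm_head weight.
    logits = hidden @ weight.T
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), ignore_index=-100
    )


class TextLM(nn.Module):
    def __init__(self, vocab=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, input_ids, defer_logits_calculation=False):
        hidden = self.embed(input_ids)  # pretend this ran the decoder stack
        if defer_logits_calculation:
            return hidden  # the wrapper computes the loss itself
        return self.lm_head(hidden)  # normal path materializes logits


class MultimodalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = TextLM()

    def forward(self, input_ids, labels):
        hidden = self.language_model(input_ids, defer_logits_calculation=True)
        return fused_linear_ce(hidden, self.language_model.lm_head.weight, labels)


if __name__ == "__main__":
    model = MultimodalLM()
    ids = torch.randint(0, 100, (2, 8))
    print(model(ids, labels=ids).item())
```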
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the causal model to enable deferred logits calculation - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mistral3.Mistral3ForConditionalGeneration.forward = cce_forward_multimodal - # patch the causal model to enable deferred logits calculation - modeling_mistral.MistralForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py deleted file mode 100644 index e82853e6c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mllama.py +++ /dev/null @@ -1,366 +0,0 @@ -"""Mllama CCE patch.""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.cache_utils import Cache -from transformers.modeling_outputs import CausalLMOutputWithPast -from transformers.models.mllama.modeling_mllama import ( - _prepare_cross_attention_mask, -) -from transformers.utils.deprecation import deprecate_kwarg - -_PATCH_OPTS: PatchOptions | None = None - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - cross_attention_states: Optional[torch.LongTensor] = None, - cross_attention_mask: Optional[torch.LongTensor] = None, - full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - defer_logits_calculation: bool = False, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - defer_logits_calculation (`bool`, *optional*): - If `True`, defer logits calculation to the ConditionalGeneration forward. 
This is used to avoid the - memory overhead of calculating logits using regular lm_head forward pass and to use CCE. - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, MllamaForCausalLM - - >>> model = MllamaForCausalLM.from_pretrained("Llama-3.2-11B-Vision") - >>> tokenizer = AutoTokenizer.from_pretrained("Llama-3.2-11B-Vision") - - >>> prompt = "If I had to write a haiku, it would be:" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=40, do_sample=True, temperature=0.6) - >>> result = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - >>> print(result) - If I had to write a haiku, it would be: "Snowflakes gently fall" - simple, yet peaceful. - I love the idea of snowflakes gently falling, each one - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( - input_ids=input_ids, - cross_attention_states=cross_attention_states, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - elif _PATCH_OPTS is not None and defer_logits_calculation: - # defer logits calculation to the ConditionalGeneration forward - logits = hidden_states[:, slice_indices, :] - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]).float() - - loss = None - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - pixel_values: Optional[torch.FloatTensor] = None, - aspect_ratio_mask: Optional[torch.Tensor] = None, - aspect_ratio_ids: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_mask: Optional[torch.Tensor] = None, - cross_attention_states: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: 
Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> Union[Tuple, CausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, MllamaForConditionalGeneration - - >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" - >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) - >>> processor = AutoProcessor.from_pretrained(checkpoint) - - >>> prompt = "<|image|>If I had to write a haiku for this one" - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> inputs = processor(text=prompt, images=image, return_tensors="pt") - - >>> # Generate - >>> output = model.generate(**inputs, max_new_tokens=15) - - >>> prompt_len = inputs.input_ids.shape[-1] - >>> generated_ids = output[:, prompt_len:] - >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) - >>> print(generated_text) - [', it would be:.\\nA stop sign in Chinatown.\\n'] - ``` - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if pixel_values is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" - ) - - if pixel_values is not None and cross_attention_states is not None: - raise ValueError( - "`pixel_values` and `cross_attention_states` cannot be provided simultaneously" - ) - - if pixel_values is not None: - if aspect_ratio_ids is None: - raise ValueError( - "`aspect_ratio_ids` must be provided if `pixel_values` is provided" - ) - # get vision tokens from vision model - vision_outputs = self.vision_model( - pixel_values=pixel_values, - aspect_ratio_ids=aspect_ratio_ids, - 
aspect_ratio_mask=aspect_ratio_mask, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - ) - cross_attention_states = vision_outputs[0] - cross_attention_states = self.multi_modal_projector( - cross_attention_states - ).reshape( - -1, cross_attention_states.shape[-2], self.hidden_size # type: ignore - ) - - if cross_attention_mask is not None: - cross_attention_mask, full_text_row_masked_out_mask = ( - _prepare_cross_attention_mask( - cross_attention_mask, - num_vision_tokens=self.vision_model.num_patches, - dtype=self.dtype, - ) - ) - else: - full_text_row_masked_out_mask = None - - if cross_attention_mask is not None and cache_position is not None: - cross_attention_mask = cross_attention_mask[:, :, cache_position] - full_text_row_masked_out_mask = full_text_row_masked_out_mask[ - :, :, cache_position - ] - - outputs = self.language_model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - cross_attention_states=cross_attention_states, - cross_attention_mask=cross_attention_mask, - full_text_row_masked_out_mask=full_text_row_masked_out_mask, - past_key_values=past_key_values, - use_cache=use_cache, - inputs_embeds=inputs_embeds, - output_hidden_states=output_hidden_states, - output_attentions=output_attentions, - return_dict=return_dict, - cache_position=cache_position, - logits_to_keep=logits_to_keep, - defer_logits_calculation=True, # enable deferred logits calculation - **loss_kwargs, - ) - - hidden_states = outputs[0] - loss = None - logits = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.language_model.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - # Temporary fix to calculate the loss in main class, as the model's vocab size may be resized - logits = hidden_states - - if labels is not None: - loss = self.loss_function( - logits, labels, self.config.get_text_config().vocab_size, **loss_kwargs - ) - - if not return_dict: - return (loss,) + outputs if loss is not None else outputs - - return CausalLMOutputWithPast( - loss=loss, - logits=outputs.logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -def patch_mllama( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - - global _PATCH_OPTS # pylint: disable=global-statement - from transformers.models.mllama import modeling_mllama - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_mllama.MllamaForConditionalGeneration - ), f"Expected a MllamaForConditionalGeneration model. Got {type(maybe_model)}." 
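Every `patch_*` helper in these files distinguishes the same two cases: given an already-instantiated model it binds the replacement forward to that one object with `types.MethodType`, and given only a config or model-type string it rebinds the class attribute so later instantiations pick up the patch. A small generic illustration of the two binding styles follows; the `Greeter` class is made up for the example.

```python
# Instance-level vs. class-level monkeypatching, as done by the patch_* helpers.
from types import MethodType


class Greeter:
    def greet(self):
        return "hello"


def patched_greet(self):
    return "patched hello"


# Case 1: patch a single live instance (the model is already constructed).
g = Greeter()
g.greet = MethodType(patched_greet, g)
assert g.greet() == "patched hello"
assert Greeter().greet() == "hello"  # other instances are unaffected

# Case 2: patch the class itself (the model has not been constructed yet).
Greeter.greet = patched_greet
assert Greeter().greet() == "patched hello"
```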
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - # patch the language model - maybe_model.language_model.forward = MethodType( - cce_forward, maybe_model.language_model - ) - return maybe_model - - modeling_mllama.MllamaForConditionalGeneration.forward = cce_forward_multimodal - - # patch the causal language model - modeling_mllama.MllamaForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py deleted file mode 100644 index 8176a1f0c..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. - -"""Cut Cross Entropy patcher""" - -import transformers -from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl -from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT -from cut_cross_entropy.transformers.phi3 import patch_phi3 -from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT - -from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import ( - patch_cohere, - patch_cohere2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma import patch_gemma -from axolotl.integrations.cut_cross_entropy.monkeypatch.gemma3 import ( - patch_gemma2, - patch_gemma3, - patch_gemma3_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.glm4 import ( - patch_glm, - patch_glm4, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - patch_llama, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import ( - patch_llama4, - patch_llama4_text, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mistral3 import ( - patch_mistral, - patch_mistral3, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import ( - patch_qwen2, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import ( - patch_qwen2_5_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import ( - patch_qwen2_moe, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import ( - patch_qwen2_vl, -) -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3 -from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import ( - patch_qwen3_moe, -) - -CUT_CROSS_ENTROPY_MODEL_MAPPING = { - "llama": patch_llama, - "llama4": patch_llama4, - "llama4_text": patch_llama4_text, - "mllama": patch_mllama, - "phi3": patch_phi3, - "gemma": patch_gemma, - "gemma2": patch_gemma2, - "gemma3": patch_gemma3, - "gemma3_text": patch_gemma3_text, - "mistral": patch_mistral, - "mistral3": patch_mistral3, - "qwen2": patch_qwen2, - "qwen2_moe": patch_qwen2_moe, - "qwen2_vl": patch_qwen2_vl, - "qwen2_5_vl": patch_qwen2_5_vl, - "qwen3": patch_qwen3, - "qwen3_moe": patch_qwen3_moe, - "cohere": patch_cohere, - "cohere2": patch_cohere2, - "glm": patch_glm, - "glm4": patch_glm4, -} - - -def cce_patch( - model_type_or_model: str | TransformersModelT | transformers.PretrainedConfig, - impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, - reduction: str = "mean", - filter_eps: float | str | None = "auto", - accum_e_fp32: bool = False, - accum_c_fp32: bool = False, - filter_e_grad: bool = True, - filter_c_grad: bool = True, - train_only: bool = False, -) -> TransformersModelT | None: - if isinstance(impl, 
LinearCrossEntropyImpl): - impl = impl.name.lower() - - if impl not in (v.name.lower() for v in LinearCrossEntropyImpl): - raise ValueError(f"Unknown {impl=}") - - if isinstance(model_type_or_model, transformers.PreTrainedModel): - if hasattr(model_type_or_model, "config"): - model_type = getattr( - getattr(model_type_or_model, "config", None), "model_type", None - ) - else: - raise ValueError( - "model_type_or_model is a PreTrainedModel but does not have a config attribute" - ) - elif isinstance(model_type_or_model, transformers.PretrainedConfig): - model_type = model_type_or_model.model_type - else: - model_type = model_type_or_model - - patch_options = PatchOptions( - impl=impl, - reduction=reduction, - filter_eps=filter_eps, - accum_e_fp32=accum_e_fp32, - accum_c_fp32=accum_c_fp32, - filter_e_grad=filter_e_grad, - filter_c_grad=filter_c_grad, - train_only=train_only, - ) - - if model_type in CUT_CROSS_ENTROPY_MODEL_MAPPING: - return CUT_CROSS_ENTROPY_MODEL_MAPPING[model_type]( - model_type_or_model, patch_options - ) - - raise RuntimeError(f"Unknown model type {model_type}") diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py deleted file mode 100644 index 3f6d2b3e9..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen2( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen2 import modeling_qwen2 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import ( - cce_forward, - ) - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2.Qwen2ForCausalLM - ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py deleted file mode 100644 index 16206006f..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_5_vl.py +++ /dev/null @@ -1,246 +0,0 @@ -"""Qwen2.5 VL CCE patch. 
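`patch_qwen2` above (and `patch_qwen3` later in this patch) does not define its own forward; it sets `_PATCH_OPTS` inside the llama patch module and then rebinds llama's `cce_forward` onto the Qwen class, since those architectures share the same causal-LM forward. The snippet below is a condensed, hypothetical illustration of that "borrow the donor module's forward and its module-level options" pattern; the names are stand-ins, not the real modules.

```python
# One shared forward implementation, configured through a module-level option,
# reused across model classes with the same forward signature.
from types import SimpleNamespace

donor_patch = SimpleNamespace(_PATCH_OPTS=None)  # plays the role of the llama patch module


def cce_forward(self, x):
    opts = donor_patch._PATCH_OPTS
    return f"{type(self).__name__} forward with opts={opts}"


class LlamaLike:
    forward = cce_forward


class QwenLike:
    def forward(self, x):
        return "original qwen forward"


def patch_qwen_like(patch_options):
    donor_patch._PATCH_OPTS = patch_options  # configure the shared forward
    QwenLike.forward = cce_forward  # reuse the donor implementation


patch_qwen_like({"impl": "cce"})
print(QwenLike().forward(None))  # QwenLike forward with opts={'impl': 'cce'}
```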
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( - Qwen2_5_VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, - second_per_grid_ts: Optional[torch.Tensor] = None, -) -> Union[Tuple, Qwen2_5_VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration - - >>> model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.dtype) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - - mask = input_ids == self.config.image_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - image_mask = mask_expanded.to(inputs_embeds.device) - - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.dtype) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - - mask = input_ids == self.config.video_token_id - mask_unsqueezed = mask.unsqueeze(-1) - mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) - video_mask = mask_expanded.to(inputs_embeds.device) - - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, - image_grid_thw, - video_grid_thw, - second_per_grid_ts, - attention_mask, - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2_5_VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_5_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration - ), f"Expected a Qwen2_5_VLForConditionalGeneration model. Got {type(maybe_model)}." 
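The Qwen2.5-VL forward above (like the Mistral3 and Llama4 ones earlier) splices the vision tower's output into the token embeddings by building a boolean mask over the image placeholder token ids and calling `masked_scatter`. Below is a small self-contained sketch of just that splice; the token id and dimensions are invented for illustration.

```python
# How image embeddings replace placeholder-token embeddings via masked_scatter.
import torch

IMAGE_TOKEN_ID = 9  # hypothetical placeholder id

input_ids = torch.tensor([[1, 9, 9, 2, 3]])  # two image placeholders
inputs_embeds = torch.zeros(1, 5, 4)  # [batch, seq, hidden]
image_embeds = torch.arange(8, dtype=torch.float32).reshape(2, 4)  # one row per placeholder

mask = (input_ids == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(inputs_embeds)
assert mask[..., 0].sum() == image_embeds.size(0), "token/feature count mismatch"

inputs_embeds = inputs_embeds.masked_scatter(mask, image_embeds)
print(inputs_embeds[0, 1])  # tensor([0., 1., 2., 3.])
print(inputs_embeds[0, 2])  # tensor([4., 5., 6., 7.])
```

The count check mirrors the `ValueError` raised in the patched forwards when the number of placeholder tokens does not match the number of vision features.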
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_5_vl.Qwen2_5_VLForConditionalGeneration.forward = ( - cce_forward_multimodal - ) - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py deleted file mode 100644 index afe56266e..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ /dev/null @@ -1,178 +0,0 @@ -"""Qwen2 MoE CCE patch. Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **loss_kwargs, -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen2MoeForCausalLM - - >>> model = Qwen2MoeForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - ) - - hidden_states = outputs.last_hidden_state - loss = None - logits = None - - if hidden_states is None: - raise ValueError("hidden_states is None") - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **loss_kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen2_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_moe import modeling_qwen2_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_moe.Qwen2MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen2_moe.Qwen2MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py deleted file mode 100644 index 79af01cfa..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Qwen2 VL CCE patch. 
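The Qwen2-MoE forward above keeps the CCE loss path for the language-model term and, when router logits are requested, adds the load-balancing auxiliary loss scaled by `router_aux_loss_coef`. The snippet below only illustrates how the two terms are combined; `aux_balance_loss` is a simplified stand-in, not transformers' `load_balancing_loss_func`.

```python
# Simplified combination of the LM loss with an MoE load-balancing penalty.
import torch


def aux_balance_loss(router_probs: torch.Tensor) -> torch.Tensor:
    # Stand-in penalty for uneven expert usage: variance of mean routing probability.
    mean_per_expert = router_probs.mean(dim=0)
    return mean_per_expert.var()


lm_loss = torch.tensor(2.31)
router_probs = torch.softmax(torch.randn(64, 8), dim=-1)  # [tokens, experts]
router_aux_loss_coef = 0.001

total_loss = lm_loss + router_aux_loss_coef * aux_balance_loss(router_probs)
print(total_loss)
```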
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Tuple, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from torch.nn import CrossEntropyLoss -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLCausalLMOutputWithPast, -) - -_PATCH_OPTS: PatchOptions | None = None - - -def cce_forward_multimodal( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - pixel_values: Optional[torch.Tensor] = None, - pixel_values_videos: Optional[torch.FloatTensor] = None, - image_grid_thw: Optional[torch.LongTensor] = None, - video_grid_thw: Optional[torch.LongTensor] = None, - rope_deltas: Optional[torch.LongTensor] = None, - cache_position: Optional[torch.LongTensor] = None, -) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Returns: - - Example: - - ```python - >>> from PIL import Image - >>> import requests - >>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration - - >>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") - - >>> messages = [ - { - "role": "user", - "content": [ - {"type": "image"}, - {"type": "text", "text": "What is shown in this image?"}, - ], - }, - ] - >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" - >>> image = Image.open(requests.get(url, stream=True).raw) - - >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - if inputs_embeds is None: - inputs_embeds = self.model.embed_tokens(input_ids) - if pixel_values is not None: - pixel_values = pixel_values.type(self.visual.get_dtype()) - image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) - n_image_tokens = (input_ids == self.config.image_token_id).sum().item() - n_image_features = image_embeds.shape[0] - if n_image_tokens != n_image_features: - raise ValueError( - f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" - ) - image_mask = ( - (input_ids == self.config.image_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) # type: ignore - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) - video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) - n_video_tokens = (input_ids == self.config.video_token_id).sum().item() - n_video_features = video_embeds.shape[0] - if n_video_tokens != n_video_features: - raise ValueError( - f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" - ) - video_mask = ( - (input_ids == self.config.video_token_id) - .unsqueeze(-1) - .expand_as(inputs_embeds) - .to(inputs_embeds.device) - ) - video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) - inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) # type: ignore - - if attention_mask is not None: - attention_mask = attention_mask.to(inputs_embeds.device) - - # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme - if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): - # calculate RoPE index once per generation in the pre-fill stage only - if ( - (cache_position is not None and cache_position[0] == 0) - or self.rope_deltas is None - or (past_key_values is None or past_key_values.get_seq_length() == 0) # type: ignore - ): - position_ids, rope_deltas = self.get_rope_index( - input_ids, image_grid_thw, video_grid_thw, attention_mask - ) - self.rope_deltas = rope_deltas - # then use the prev pre-calculated rope-deltas to get the correct position ids - else: - batch_size, seq_length, _ = inputs_embeds.shape - delta = ( - cache_position[0] + self.rope_deltas - if cache_position is not None - else 0 - ) - position_ids = torch.arange(seq_length, device=inputs_embeds.device) # type: ignore - position_ids = position_ids.view(1, -1).expand(batch_size, -1) # type: ignore - if cache_position is not None: # otherwise `deltas` is an int `0` - delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) # type: ignore - delta = delta.to(position_ids.device) # type: ignore - position_ids = position_ids.add(delta) # type: ignore - position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) # type: ignore - - outputs = self.model( - input_ids=None, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - cache_position=cache_position, - ) - - hidden_states = outputs[0] - logits = None - loss = None - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states, - self.lm_head.weight, - labels, - _PATCH_OPTS, - ) - else: - logits = self.lm_head(hidden_states) - - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - - if not return_dict: - output = (logits,) + outputs[1:] - return (loss,) + output if loss is not None else output - - return Qwen2VLCausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - rope_deltas=self.rope_deltas, - ) - - -def patch_qwen2_vl( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen2_vl import modeling_qwen2_vl - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen2_vl.Qwen2VLForConditionalGeneration - ), f"Expected a Qwen2VLForConditionalGeneration model. Got {type(maybe_model)}." 
- maybe_model.forward = MethodType(cce_forward_multimodal, maybe_model) - - return maybe_model - - modeling_qwen2_vl.Qwen2VLForConditionalGeneration.forward = cce_forward_multimodal - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py deleted file mode 100644 index 799a4f357..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3.py +++ /dev/null @@ -1,35 +0,0 @@ -"""Qwen3 CCE patch. The model inherits Llama's modeling code and uses the same forward method.""" - -# pylint: disable=duplicate-code - -from types import MethodType - -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, -) - - -def patch_qwen3( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - from transformers.models.qwen3 import modeling_qwen3 - - # Set the _PATCH_OPTS in the llama patch file - import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch - - llama_patch._PATCH_OPTS = patch_options # pylint: disable=protected-access - - from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3.Qwen3ForCausalLM - ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(cce_forward, maybe_model) - return maybe_model - - modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py deleted file mode 100644 index 90466e64b..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ /dev/null @@ -1,183 +0,0 @@ -"""Qwen3 MoE CCE patch. 
Adapted from transformers v4.51.2""" - -# pylint: disable=duplicate-code - -from types import MethodType -from typing import Optional, Union - -import torch -import transformers -from cut_cross_entropy.transformers.utils import ( - PatchOptions, - TransformersModelT, - apply_lce, -) -from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - KwargsForCausalLM, - MoeCausalLMOutputWithPast, - MoeModelOutputWithPast, - load_balancing_loss_func, -) -from transformers.processing_utils import Unpack -from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import can_return_tuple - -_PATCH_OPTS: PatchOptions | None = None - - -@can_return_tuple -@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[list[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_router_logits: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - logits_to_keep: Union[int, torch.Tensor] = 0, - **kwargs: Unpack[KwargsForCausalLM], -) -> MoeCausalLMOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - logits_to_keep (`int` or `torch.Tensor`, *optional*): - If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all - `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that - token can save memory, which becomes pretty significant for long sequences or large vocabulary size. - If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. - This is useful when using packed tensor format (single dimension for batch and sequence length). - - Returns: - - Example: - - ```python - >>> from transformers import AutoTokenizer, Qwen3MoeForCausalLM - - >>> model = Qwen3MoeForCausalLM.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-MoE-15B-A2B") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_router_logits = ( - output_router_logits - if output_router_logits is not None - else self.config.output_router_logits - ) - - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs: MoeModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - output_router_logits=output_router_logits, - cache_position=cache_position, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - - if hidden_states is None: - raise ValueError("hidden_states is None") - - loss = None - logits = None - - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) - if isinstance(logits_to_keep, int) - else logits_to_keep - ) - - if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training): - assert labels is not None - loss = apply_lce( - hidden_states[:, slice_indices, :], - self.lm_head.weight, - labels, - _PATCH_OPTS, - **kwargs, - ) - else: - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - if labels is not None: - loss = self.loss_function(logits, labels, self.vocab_size, **kwargs) - - aux_loss = None - if output_router_logits: - aux_loss = load_balancing_loss_func( - outputs.router_logits, - self.num_experts, - self.num_experts_per_tok, - attention_mask, - ) - if labels is not None: - loss += self.router_aux_loss_coef * aux_loss.to( # type: ignore - loss.device # type: ignore - ) # make sure to reside in the same device - - return MoeCausalLMOutputWithPast( - loss=loss, - aux_loss=aux_loss, # type: ignore - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - router_logits=outputs.router_logits, - ) - - -def patch_qwen3_moe( - maybe_model: TransformersModelT | str | transformers.PretrainedConfig, - patch_options: PatchOptions, -) -> TransformersModelT | None: - global _PATCH_OPTS # pylint: disable=global-statement - - from transformers.models.qwen3_moe import modeling_qwen3_moe - - _PATCH_OPTS = patch_options - - if isinstance(maybe_model, transformers.PreTrainedModel): - assert isinstance( - maybe_model, modeling_qwen3_moe.Qwen3MoeForCausalLM - ), f"Expected a Qwen3MoeForCausalLM model. Got {type(maybe_model)}." - maybe_model.forward = MethodType(forward, maybe_model) - - return maybe_model - - modeling_qwen3_moe.Qwen3MoeForCausalLM.forward = forward - return None diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py deleted file mode 100644 index b808b9f0d..000000000 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
- -"""Monkeypatch for apply_lce to add softcap.""" - -import torch -from cut_cross_entropy import linear_cross_entropy -from cut_cross_entropy.transformers.utils import PatchOptions - - -def apply_lce( - e: torch.Tensor, - c: torch.Tensor, - labels: torch.Tensor, - opts: PatchOptions, - bias: torch.Tensor | None = None, - softcap: float | None = None, - **loss_kwargs, -) -> torch.Tensor: - """Monkey patch for apply_lce to support softcap kwarg.""" - num_items_in_batch = loss_kwargs.get("num_items_in_batch", None) - cce_kwargs = opts.to_kwargs() - if num_items_in_batch is not None and cce_kwargs["reduction"] == "mean": - cce_kwargs["reduction"] = "sum" - else: - num_items_in_batch = None - - loss = linear_cross_entropy( - e, - c, - labels.to(e.device), - bias=bias, - shift=True, - softcap=softcap, - **cce_kwargs, - ) - - if num_items_in_batch is not None: - loss = loss / num_items_in_batch - - return loss diff --git a/src/axolotl/processing_strategies.py b/src/axolotl/processing_strategies.py index ce9b6a838..080697400 100644 --- a/src/axolotl/processing_strategies.py +++ b/src/axolotl/processing_strategies.py @@ -142,7 +142,7 @@ class ProcessingStrategy: # TODO: check if it's normal to be single image only for common datasets # From observation, it's usually a list of single image but some datasets may have several columns for images # Temporary solution: take the first image and suggest people convert their datasets to use multi-content Messages - if len(processed_example[image_key]) > 0: + if len(processed_example[image_key]) > 1: LOG.warning( f"Found {len(processed_example[image_key])} images in a sample. Using the first one." "If you are using a dataset with multiple images per sample, please convert it to use multi-content Messages."