diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py index 99e17910e..ea9e10724 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/cohere.py @@ -20,25 +20,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.cohere.modeling_cohere import ( - _CONFIG_FOR_DOC, - COHERE_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py index 4c8d2261a..ae3d8c6ef 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma.py @@ -17,25 +17,15 @@ from cut_cross_entropy.transformers.utils import ( from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma.modeling_gemma import ( - _CONFIG_FOR_DOC, - GEMMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py index ccf0c160d..644e5cce7 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/gemma3.py @@ -20,15 +20,11 @@ from torch import nn from transformers.cache_utils import Cache, HybridCache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -38,10 +34,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -170,10 +162,6 @@ def cce_forward( @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py index 42ab996b9..bed411ace 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py @@ -19,15 +19,9 @@ from transformers.modeling_outputs import ( CausalLMOutputWithPast, ) from transformers.models.llama.modeling_llama import ( - _CONFIG_FOR_DOC, - LLAMA_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -36,10 +30,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py index 7204f5c90..3143e9c8d 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama4.py @@ -16,22 +16,12 @@ from torch import nn from transformers.cache_utils import Cache from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.llama4.modeling_llama4 import ( - _CONFIG_FOR_DOC, - LLAMA4_INPUTS_DOCSTRING, Llama4CausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(LLAMA4_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, @@ -160,9 +150,6 @@ def cce_forward( ) -@replace_return_docstrings( - output_type=Llama4CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: torch.LongTensor | None = None, # type: ignore diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py index adb65fa8f..aa252701e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/mistral3.py @@ -19,15 +19,11 @@ from transformers.models.mistral3.modeling_mistral3 import ( Mistral3CausalLMOutputWithPast, ) from transformers.models.mistral.modeling_mistral import ( - _CONFIG_FOR_DOC, - MISTRAL_INPUTS_DOCSTRING, KwargsForCausalLM, ) from transformers.processing_utils import Unpack from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @@ -35,10 +31,6 @@ _PATCH_OPTS: PatchOptions | None = None @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward( self, input_ids: torch.LongTensor | None = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py index 0811bf55a..afe56266e 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_moe.py @@ -13,16 +13,10 @@ from cut_cross_entropy.transformers.utils import ( apply_lce, ) from transformers.models.qwen2_moe.modeling_qwen2_moe import ( - _CONFIG_FOR_DOC, - QWEN2MOE_INPUTS_DOCSTRING, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -31,10 +25,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py index 250c3ab6b..79af01cfa 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2_vl.py @@ -14,22 +14,12 @@ from cut_cross_entropy.transformers.utils import ( ) from torch.nn import CrossEntropyLoss from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - _CONFIG_FOR_DOC, - QWEN2_VL_INPUTS_DOCSTRING, Qwen2VLCausalLMOutputWithPast, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) _PATCH_OPTS: PatchOptions | None = None -@add_start_docstrings_to_model_forward(QWEN2_VL_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Qwen2VLCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def cce_forward_multimodal( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py index c5cd76f94..90466e64b 100644 --- a/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py +++ b/src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen3_moe.py @@ -12,20 +12,13 @@ from cut_cross_entropy.transformers.utils import ( TransformersModelT, apply_lce, ) -from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.models.qwen3_moe.modeling_qwen3_moe import ( - _CONFIG_FOR_DOC, - QWEN3_MOE_INPUTS_DOCSTRING, KwargsForCausalLM, MoeCausalLMOutputWithPast, MoeModelOutputWithPast, load_balancing_loss_func, ) from transformers.processing_utils import Unpack -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) from transformers.utils.deprecation import deprecate_kwarg from transformers.utils.generic import can_return_tuple @@ -34,10 +27,6 @@ _PATCH_OPTS: PatchOptions | None = None @can_return_tuple @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(QWEN3_MOE_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/axolotl/integrations/liger/models/deepseekv2.py b/src/axolotl/integrations/liger/models/deepseekv2.py index c29fd4e79..2f0d2a704 100644 --- a/src/axolotl/integrations/liger/models/deepseekv2.py +++ b/src/axolotl/integrations/liger/models/deepseekv2.py @@ -14,10 +14,6 @@ from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import CausalLMOutputWithPast -# @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) -# @replace_return_docstrings( -# output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -# ) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/integrations/liger/models/jamba.py b/src/axolotl/integrations/liger/models/jamba.py index 7ab464c88..d25529970 100644 --- a/src/axolotl/integrations/liger/models/jamba.py +++ b/src/axolotl/integrations/liger/models/jamba.py @@ -13,21 +13,11 @@ from liger_kernel.transformers.fused_linear_cross_entropy import ( from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.models.jamba.modeling_jamba import ( - _CONFIG_FOR_DOC, - JAMBA_INPUTS_DOCSTRING, HybridMambaAttentionDynamicCache, load_balancing_loss_func, ) -from transformers.utils import ( - add_start_docstrings_to_model_forward, - replace_return_docstrings, -) -@add_start_docstrings_to_model_forward(JAMBA_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def lce_forward( self, input_ids: torch.LongTensor = None, diff --git a/src/axolotl/monkeypatch/gemma3.py b/src/axolotl/monkeypatch/gemma3.py index 38183fa0e..36f591efd 100644 --- a/src/axolotl/monkeypatch/gemma3.py +++ b/src/axolotl/monkeypatch/gemma3.py @@ -7,24 +7,16 @@ from typing import Optional, Tuple, Union import torch from transformers.cache_utils import Cache from transformers.models.gemma3.modeling_gemma3 import ( - _CONFIG_FOR_DOC, - GEMMA3_INPUTS_DOCSTRING, Gemma3CausalLMOutputWithPast, logger, ) from transformers.utils import ( - add_start_docstrings_to_model_forward, is_torchdynamo_compiling, - replace_return_docstrings, ) from transformers.utils.deprecation import deprecate_kwarg @deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep") -@add_start_docstrings_to_model_forward(GEMMA3_INPUTS_DOCSTRING) -@replace_return_docstrings( - output_type=Gemma3CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC -) def new_forward( self, input_ids: torch.LongTensor = None,