From d320ef619988d6cf5151d5a9da88979dc7f91bfe Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 15 Jul 2025 11:28:41 -0400
Subject: [PATCH] fix for upstream refactor of KwargsForCausalLM (#2911)

---
 src/axolotl/integrations/kd/kernels/models.py | 20 ++++++++++++-------
 1 file changed, 13 insertions(+), 7 deletions(-)

diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
index 5a7c286bc..6a8b6da1c 100644
--- a/src/axolotl/integrations/kd/kernels/models.py
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
 
 import torch
 from transformers import Cache
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.utils import LossKwargs
 
+try:
+    from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+    from transformers.utils import LossKwargs
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
-    """
-    placeholder kwargs for hf model classes
-    """
+    class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
+        """
+        placeholder kwargs for hf model classes
+        """
+
+except ImportError:
+    from transformers.utils.generic import (  # type: ignore[no-redef]
+        TransformersKwargs,
+    )
 
 
 def kldiv_forward_llama_like(
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
     output_hidden_states: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,  # pylint: disable=unused-argument
-    **kwargs: Unpack[KwargsForCausalLM],  # type: ignore[misc]
+    **kwargs: Unpack[TransformersKwargs],  # type: ignore[misc]
 ) -> CausalLMOutputWithPast:
     # pylint: disable=duplicate-code
     output_attentions = (
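
Note: the first hunk is a compatibility shim rather than a behavior change.
On transformers releases that predate the upstream refactor, the imports in
the try block succeed and a local TransformersKwargs is composed from
FlashAttentionKwargs and LossKwargs; on releases where the refactor removed
those names, the ImportError branch imports the upstream TransformersKwargs
from transformers.utils.generic instead. Either way the same name is defined,
so the second hunk's signature needs no version check. A minimal sketch of
how downstream code consumes the resolved name (this forward() is
hypothetical; only the typed **kwargs mirrors the patched signature):

    from typing import Optional, Unpack

    import torch

    def forward(
        input_ids: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],  # same annotation on old or new transformers
    ) -> None:
        # The TypedDict-style kwargs carry flash-attention and loss-related
        # options through to the wrapped HF model without enumerating each
        # field in the signature.
        ...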