fix for upstream refactor of KwargsForCausalLM (#2911)

@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
 
 import torch
 from transformers import Cache
-from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.utils import LossKwargs
 
+try:
+    from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+    from transformers.utils import LossKwargs
 
-class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
-    """
-    placeholder kwargs for hf model classes
-    """
+    class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
+        """
+        placeholder kwargs for hf model classes
+        """
+
+except ImportError:
+    from transformers.utils.generic import (  # type: ignore[no-redef]
+        TransformersKwargs,
+    )
 
 
 def kldiv_forward_llama_like(
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
     output_hidden_states: Optional[bool] = None,
     cache_position: Optional[torch.LongTensor] = None,
     logits_to_keep: Union[int, torch.Tensor] = 0,  # pylint: disable=unused-argument
-    **kwargs: Unpack[KwargsForCausalLM],  # type: ignore[misc]
+    **kwargs: Unpack[TransformersKwargs],  # type: ignore[misc]
 ) -> CausalLMOutputWithPast:
     # pylint: disable=duplicate-code
     output_attentions = (
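
The import hunk above is the usual try/except compatibility shim: on transformers versions that predate the refactor, FlashAttentionKwargs and LossKwargs still import and are combined into a locally defined TransformersKwargs; on newer versions those imports raise ImportError and the upstream TransformersKwargs is taken from transformers.utils.generic instead, so the rest of the module uses one name either way. A minimal, self-contained sketch of the same idiom, written against the standard library so it runs anywhere (functools.cache appeared in Python 3.9; the except branch rebuilds it from lru_cache):

# Same try/except import idiom as the shim above: prefer one layout,
# reconstruct an equivalent symbol when the import is gone.
try:
    from functools import cache  # available on Python >= 3.9
except ImportError:
    from functools import lru_cache

    cache = lru_cache(maxsize=None)  # type: ignore[assignment]  # pre-3.9 fallback


@cache
def fib(n: int) -> int:
    """Cached Fibonacci; behaves identically on either import path."""
    return n if n < 2 else fib(n - 1) + fib(n - 2)


print(fib(30))  # 832040

Which branch comes first is only a question of which library versions the codebase treats as the common case; the diff tries the pre-refactor layout first and falls back to the new one.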
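The second hunk only renames the annotation on **kwargs, but the typing pattern behind it is worth spelling out: Unpack[SomeTypedDict] (PEP 692) lets a static checker validate the names and types of keyword arguments forwarded through **kwargs, with no effect at runtime; the diff keeps a # type: ignore[misc] on that line, presumably for checkers that predate full support. A short sketch of the pattern; the TypedDict and its fields below are invented for illustration and are not the real TransformersKwargs:

from typing import Optional

from typing_extensions import TypedDict, Unpack  # both re-exported by stdlib typing on newer Pythons


class FakeModelKwargs(TypedDict, total=False):
    # Illustrative fields only; the real TransformersKwargs is defined upstream.
    num_items_in_batch: Optional[int]
    output_router_logits: bool


def forward(**kwargs: Unpack[FakeModelKwargs]) -> None:
    # At runtime this is an ordinary dict; the annotation only guides
    # static checkers such as mypy or pyright.
    print(sorted(kwargs))


forward(num_items_in_batch=4)  # accepted
# forward(num_items=4)         # a checker flags the unknown keyword argument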