fix for upstream refactor of KwargsForCausalLM (#2911)
This commit is contained in:
@@ -6,15 +6,21 @@ from typing import Optional, Union, Unpack
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import Cache
|
from transformers import Cache
|
||||||
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
|
||||||
from transformers.modeling_outputs import CausalLMOutputWithPast
|
from transformers.modeling_outputs import CausalLMOutputWithPast
|
||||||
from transformers.utils import LossKwargs
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
|
||||||
|
from transformers.utils import LossKwargs
|
||||||
|
|
||||||
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
class TransformersKwargs(FlashAttentionKwargs, LossKwargs):
|
||||||
"""
|
"""
|
||||||
placeholder kwargs for hf model classes
|
placeholder kwargs for hf model classes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
from transformers.utils.generic import ( # type: ignore[no-redef]
|
||||||
|
TransformersKwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def kldiv_forward_llama_like(
|
def kldiv_forward_llama_like(
|
||||||
@@ -33,7 +39,7 @@ def kldiv_forward_llama_like(
|
|||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
cache_position: Optional[torch.LongTensor] = None,
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
|
||||||
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
|
**kwargs: Unpack[TransformersKwargs], # type: ignore[misc]
|
||||||
) -> CausalLMOutputWithPast:
|
) -> CausalLMOutputWithPast:
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
output_attentions = (
|
output_attentions = (
|
||||||
|
|||||||
Reference in New Issue
Block a user