update for recent transformers updates (#636)

* update for recent transformers updates

* fix checkpoint forward kwargs

* just pass args into torch checkpoint
Author: Wing Lian
Date: 2023-09-27 12:10:32 -04:00
Committed by: GitHub
Parent: e8cbf50be6
Commit: 60c7c48c97


@@ -99,6 +99,7 @@ def flashattn_forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
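For context, recent transformers releases (around v4.34, which this commit targets) started passing a padding_mask keyword into the Llama attention forward, so a patched forward has to accept it even when it never uses it. Below is a minimal sketch of that signature-compatibility pattern; the function name is hypothetical, and the real flashattn_forward above relies on cu_seqlens/max_seqlen rather than padding_mask.

from typing import Optional
import torch

def patched_attn_forward(  # hypothetical stand-in for the patched attention forward
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    padding_mask: Optional[torch.LongTensor] = None,  # accepted only for API compatibility
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[torch.Tensor] = None,
):
    # The flash-attention path packs sequences with cu_seqlens/max_seqlen,
    # so padding_mask exists here only to match the newer transformers call signature.
    ...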
@@ -476,6 +477,13 @@ def llama_model_forward(
dtype=torch.bool,
device=inputs_embeds.device,
)
padding_mask = None
else:
if 0 in attention_mask:
padding_mask = attention_mask
else:
padding_mask = None
attention_mask = (
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
attention_mask,
@@ -510,7 +518,9 @@ def llama_model_forward(
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs)
return module(
*inputs,
)
return custom_forward
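The reason the wrapper now forwards *inputs unchanged is that torch.utils.checkpoint.checkpoint (in its default reentrant mode) does not carry keyword arguments through to the wrapped function, so everything the layer needs, padding_mask included, has to be passed positionally. A minimal sketch of the pattern, using a hypothetical toy_layer rather than the patched decoder layer:

import torch
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        # checkpoint replays this call during backward, and only positional
        # arguments survive the round trip, so nothing may be passed as a kwarg.
        return module(*inputs)
    return custom_forward

# Toy usage: a hypothetical two-argument layer checkpointed the same way.
toy_layer = lambda x, scale: x * scale
out = checkpoint(
    create_custom_forward(toy_layer),
    torch.randn(2, 4, requires_grad=True),
    torch.tensor(2.0),
)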
@@ -519,9 +529,10 @@ def llama_model_forward(
hidden_states,
attention_mask,
position_ids,
None,
past_key_value,
output_attentions,
None,
padding_mask,
cu_seqlens,
max_seqlen,
)
@@ -533,6 +544,7 @@ def llama_model_forward(
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
@@ -579,6 +591,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[torch.Tensor] = None,
) -> Tuple[
@@ -611,6 +624,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
)
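The other half of the change is how padding_mask is derived before being threaded through the decoder layers: the raw attention_mask is kept only when real padding is present, otherwise None is passed so the attention layers can skip the padding-aware path. A minimal standalone sketch of that logic, with a hypothetical helper name and assuming a 2D attention_mask of shape (batch, seq_len):

from typing import Optional
import torch

def derive_padding_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Mirrors the earlier hunk: a 0 in the mask means at least one padded
    # position, so hand the mask to the attention layers; otherwise pass None.
    if attention_mask is None:
        return None
    return attention_mask if 0 in attention_mask else None

mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
print(derive_padding_mask(mask))  # returns the mask: sequence 0 has padding
print(derive_padding_mask(torch.ones(2, 4, dtype=torch.long)))  # None: nothing padded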