update for recent transformers updates (#636)
* update for recent transformers updates * fix checkpoint forward kwargs * just pass args into torch checkpoint
This commit is contained in:
@@ -99,6 +99,7 @@ def flashattn_forward(
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
@@ -476,6 +477,13 @@ def llama_model_forward(
|
|||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=inputs_embeds.device,
|
device=inputs_embeds.device,
|
||||||
)
|
)
|
||||||
|
padding_mask = None
|
||||||
|
else:
|
||||||
|
if 0 in attention_mask:
|
||||||
|
padding_mask = attention_mask
|
||||||
|
else:
|
||||||
|
padding_mask = None
|
||||||
|
|
||||||
attention_mask = (
|
attention_mask = (
|
||||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||||
attention_mask,
|
attention_mask,
|
||||||
@@ -510,7 +518,9 @@ def llama_model_forward(
|
|||||||
def create_custom_forward(module):
|
def create_custom_forward(module):
|
||||||
def custom_forward(*inputs):
|
def custom_forward(*inputs):
|
||||||
# None for past_key_value
|
# None for past_key_value
|
||||||
return module(*inputs)
|
return module(
|
||||||
|
*inputs,
|
||||||
|
)
|
||||||
|
|
||||||
return custom_forward
|
return custom_forward
|
||||||
|
|
||||||
@@ -519,9 +529,10 @@ def llama_model_forward(
|
|||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
position_ids,
|
position_ids,
|
||||||
None,
|
past_key_value,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
None,
|
None,
|
||||||
|
padding_mask,
|
||||||
cu_seqlens,
|
cu_seqlens,
|
||||||
max_seqlen,
|
max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -533,6 +544,7 @@ def llama_model_forward(
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
@@ -579,6 +591,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = False,
|
use_cache: Optional[bool] = False,
|
||||||
|
padding_mask: Optional[torch.LongTensor] = None,
|
||||||
cu_seqlens: Optional[torch.Tensor] = None,
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
max_seqlen: Optional[torch.Tensor] = None,
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
) -> Tuple[
|
) -> Tuple[
|
||||||
@@ -611,6 +624,7 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
|||||||
past_key_value=past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
padding_mask=padding_mask,
|
||||||
cu_seqlens=cu_seqlens,
|
cu_seqlens=cu_seqlens,
|
||||||
max_seqlen=max_seqlen,
|
max_seqlen=max_seqlen,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user