Compare commits

1 commit — branch `flex_patch` ... `llama-flas` (branch name truncated in extraction; presumably `llama-flash` — verify against the repository)

| Author | SHA1 | Date |
|---|---|---|
| (not captured) | 8c171aadb4 | (not captured) |
```diff
@@ -245,7 +245,6 @@ def flashattn_forward_with_s2attn(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
     max_seqlen: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
```
```diff
@@ -374,7 +373,6 @@ def flashattn_forward(
     past_key_value: Optional[Tuple[torch.Tensor]] = None,
     output_attentions: bool = False,
     use_cache: bool = False,
-    padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
     cu_seqlens: Optional[torch.Tensor] = None,
     max_seqlen: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
```
```diff
@@ -770,12 +768,6 @@ def llama_model_forward(
             dtype=torch.bool,
             device=inputs_embeds.device,
         )
-        padding_mask = None
-    else:
-        if 0 in attention_mask:
-            padding_mask = attention_mask
-        else:
-            padding_mask = None

     attention_mask = (
         self._prepare_decoder_attention_mask( # pylint: disable=protected-access
```
(NOTE: interior indentation was lost in extraction and is reconstructed here — confirm against the original file.)
```diff
@@ -825,7 +817,6 @@ def llama_model_forward(
                 past_key_value,
                 output_attentions,
                 None,
-                padding_mask,
                 cu_seqlens,
                 max_seqlen,
             )
```
(NOTE: call-site indentation reconstructed — confirm against the original file.)
```diff
@@ -837,7 +828,6 @@ def llama_model_forward(
                 past_key_value=past_key_value,
                 output_attentions=output_attentions,
                 use_cache=use_cache,
-                padding_mask=padding_mask,
                 cu_seqlens=cu_seqlens,
                 max_seqlen=max_seqlen,
             )
```
(NOTE: call-site indentation reconstructed — confirm against the original file.)
```diff
@@ -884,7 +874,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        padding_mask: Optional[torch.LongTensor] = None,
         cu_seqlens: Optional[torch.Tensor] = None,
         max_seqlen: Optional[torch.Tensor] = None,
     ) -> Tuple[
```
```diff
@@ -917,7 +906,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
             past_key_value=past_key_value,
             output_attentions=output_attentions,
             use_cache=use_cache,
-            padding_mask=padding_mask,
             cu_seqlens=cu_seqlens,
             max_seqlen=max_seqlen,
         )
```
(NOTE: call-site indentation reconstructed — confirm against the original file.)
|
|||||||
Reference in New Issue
Block a user