don't worry about dupes
This commit is contained in:
@@ -25,6 +25,7 @@ def forward(
|
|||||||
|
|
||||||
attention_mask: [bsz, q_len]
|
attention_mask: [bsz, q_len]
|
||||||
"""
|
"""
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
query_states = (
|
query_states = (
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ def xformers_forward(
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
query_states = (
|
query_states = (
|
||||||
@@ -143,6 +144,7 @@ def sdp_attention_forward(
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
use_cache: bool = False,
|
use_cache: bool = False,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
query_states = (
|
query_states = (
|
||||||
|
|||||||
Reference in New Issue
Block a user