monkeypatch.mistral_attn_hijack_flash
Flash attention monkey patch for the Mistral model.
Classes

| Name | Description |
| --- | --- |
| MistralDecoderLayer | Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens. |
MistralDecoderLayer

monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer()

Patched version of MistralDecoderLayer that passes through the precalculated cu_seqlens.
Methods
forward
monkeypatch.mistral_attn_hijack_flash.MistralDecoderLayer.forward(
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
output_attentions=False,
use_cache=False,
cu_seqlens=None,
max_seqlen=None,
)
Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| hidden_states | torch.FloatTensor | Input to the layer of shape (batch, seq_len, embed_dim). | required |
| attention_mask | torch.FloatTensor, optional | Attention mask of size (batch, 1, tgt_len, src_len) where padding elements are indicated by very large negative values. | None |
| past_key_value | Tuple(torch.FloatTensor), optional | Cached past key and value projection states. | None |
| output_attentions | bool, optional | Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail. | False |
| use_cache | bool, optional | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | False |
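The only arguments this patched forward adds over the stock decoder layer are cu_seqlens and max_seqlen, the cumulative-sequence-length offsets and longest sample length that flash attention's variable-length kernels consume. A minimal, torch-free sketch of how such offsets are typically derived from per-sample lengths (the helper name is hypothetical, not part of this module):

```python
# Hypothetical helper illustrating the cu_seqlens format: cumulative
# offsets over the per-sample sequence lengths in a packed batch.
def cu_seqlens_from_lengths(lengths):
    """Return [0, len0, len0+len1, ...] delimiting each sequence."""
    cu = [0]
    for n in lengths:
        cu.append(cu[-1] + n)
    return cu

lengths = [3, 5, 2]
print(cu_seqlens_from_lengths(lengths))  # [0, 3, 8, 10]
print(max(lengths))                      # 5 -> the matching max_seqlen
```

Each consecutive pair of offsets brackets one sample's tokens in the packed (batch-free) token dimension.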
Functions
generate_qkv
monkeypatch.mistral_attn_hijack_flash.generate_qkv(
q,
k,
v,
query_padding_mask=None,
key_padding_mask=None,
kvpacked=False,
qkvpacked=False,
)
Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | | (batch_size, seqlen_q, nheads, d) | required |
| k | | (batch_size, seqlen_k, nheads_k, d) | required |
| v | | (batch_size, seqlen_k, nheads_k, d) | required |
| query_padding_mask | | (batch_size, seqlen), bool | None |
| key_padding_mask | | (batch_size, seqlen), bool | None |
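The padding masks are boolean (batch_size, seqlen) arrays marking real tokens. A hedged, torch-free sketch of how such a mask maps to the unpadded lengths and cu_seqlens offsets that packed qkv layouts rely on (helper names are hypothetical, not part of this module):

```python
# Hypothetical helpers showing the packed (unpadded) layout implied by
# a boolean padding mask, where True marks a real (non-pad) token.
def unpad_lengths(padding_mask):
    """Count real tokens per row of a boolean (batch_size, seqlen) mask."""
    return [sum(row) for row in padding_mask]

def cu_seqlens_from_mask(padding_mask):
    """Cumulative offsets delimiting each sequence in the packed layout."""
    cu = [0]
    for n in unpad_lengths(padding_mask):
        cu.append(cu[-1] + n)
    return cu

mask = [
    [True, True, True, False],   # sample 0: length 3, one pad token
    [True, True, False, False],  # sample 1: length 2, two pad tokens
]
print(unpad_lengths(mask))        # [3, 2]
print(cu_seqlens_from_mask(mask)) # [0, 3, 5]
```

Passing kvpacked or qkvpacked then controls whether k/v (or q/k/v) are stacked into a single tensor over this packed token dimension.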