# monkeypatch.llama_attn_hijack_flash

Flash attention monkey patch for the llama model.
## Classes

| Name | Description |
|---|---|
| FusedAttention | Fused QKV attention layer for incrementally improved training efficiency |
| LlamaDecoderLayer | Patched version of LlamaDecoderLayer that passes through the precalculated cu_seqlens |
### FusedAttention

`monkeypatch.llama_attn_hijack_flash.FusedAttention(config, q, k, v, o)`

Fused QKV attention layer for incrementally improved training efficiency.
### LlamaDecoderLayer

`monkeypatch.llama_attn_hijack_flash.LlamaDecoderLayer()`

Patched version of LlamaDecoderLayer that passes through the precalculated cu_seqlens.
**Methods**

| Name | Description |
|---|---|
| forward | Forward pass of the patched decoder layer |
#### forward

    monkeypatch.llama_attn_hijack_flash.LlamaDecoderLayer.forward(
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
    )

**Parameters**
| Name | Type | Description | Default |
|---|---|---|---|
| hidden_states | torch.FloatTensor | Input to the layer, of shape (batch, seq_len, embed_dim) | required |
| attention_mask | torch.FloatTensor, optional | Attention mask of size (batch, 1, tgt_len, src_len), where padding elements are indicated by very large negative values. | None |
| output_attentions | bool, optional | Whether or not to return the attention tensors of all attention layers. See attentions under returned tensors for more detail. | False |
| use_cache | bool, optional | If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). | False |
| past_key_value | Tuple(torch.FloatTensor), optional | Cached past key and value projection states | None |
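The cu_seqlens and max_seqlen parameters carry the variable-length batch metadata consumed by flash attention's varlen kernels: the cumulative offsets of each sequence in a packed token dimension, plus the longest sequence length. A minimal sketch of how such values can be derived from per-sequence lengths (the helper name is illustrative, not part of this module):

```python
from itertools import accumulate

def build_cu_seqlens(seq_lens):
    """Cumulative sequence-length offsets for a packed variable-length batch.

    When sequences of lengths [3, 5, 2] are packed back to back along one
    token dimension, token i of sequence j sits at index cu_seqlens[j] + i.
    """
    cu_seqlens = [0] + list(accumulate(seq_lens))
    max_seqlen = max(seq_lens)
    return cu_seqlens, max_seqlen

cu_seqlens, max_seqlen = build_cu_seqlens([3, 5, 2])
print(cu_seqlens)  # [0, 3, 8, 10]
print(max_seqlen)  # 5
```

Note that cu_seqlens has one more entry than there are sequences; consecutive entries bracket each sequence's span in the packed tensor.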
## Functions

| Name | Description |
|---|---|
| flashattn_forward | Input shape: Batch x Time x Channel |
| flashattn_forward_with_s2attn | Input shape: Batch x Time x Channel |
| generate_qkv | Convert padded q, k, v into the packed layouts expected by flash attention |
### flashattn_forward

    monkeypatch.llama_attn_hijack_flash.flashattn_forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
    )

Input shape: Batch x Time x Channel

attention_mask: [bsz, q_len]
### flashattn_forward_with_s2attn

    monkeypatch.llama_attn_hijack_flash.flashattn_forward_with_s2attn(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        padding_mask=None,
        cu_seqlens=None,
        max_seqlen=None,
    )

Input shape: Batch x Time x Channel

From: https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py

attention_mask: [bsz, q_len]

cu_seqlens and max_seqlen will be ignored if provided.
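This variant applies LongLoRA's shifted sparse attention (S2-attn): attention is computed within fixed-size token groups, and part of the heads see a view of the sequence rolled by half a group so information can flow across group boundaries. A toy, pure-Python illustration of that grouping and shift on token positions (helper names are illustrative, not the module's API):

```python
def half_group_shift(tokens, group_size):
    """Roll the sequence by half a group -- the S2-attn shift."""
    shift = group_size // 2
    return tokens[shift:] + tokens[:shift]

def split_into_groups(tokens, group_size):
    """Partition the (possibly rolled) sequence into attention groups."""
    return [tokens[i:i + group_size] for i in range(0, len(tokens), group_size)]

positions = list(range(8))  # token positions 0..7, group_size = 4

# Unshifted heads attend within [0..3] and [4..7].
print(split_into_groups(positions, 4))
# [[0, 1, 2, 3], [4, 5, 6, 7]]

# Shifted heads attend within [2..5] and [6, 7, 0, 1],
# bridging the boundary between the two plain groups.
print(split_into_groups(half_group_shift(positions, 4), 4))
# [[2, 3, 4, 5], [6, 7, 0, 1]]
```

Because every group boundary in the plain view falls inside a group in the shifted view, stacking both head sets lets information propagate across the whole sequence over successive layers at within-group attention cost.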
### generate_qkv

    monkeypatch.llama_attn_hijack_flash.generate_qkv(
        q,
        k,
        v,
        query_padding_mask=None,
        key_padding_mask=None,
        kvpacked=False,
        qkvpacked=False,
    )

**Parameters**

| Name | Type | Description | Default |
|---|---|---|---|
| q | (batch_size, seqlen_q, nheads, d) | | required |
| k | (batch_size, seqlen_k, nheads_k, d) | | required |
| v | (batch_size, seqlen_k, nheads_k, d) | | required |
| query_padding_mask | (batch_size, seqlen), bool | | None |
| key_padding_mask | (batch_size, seqlen), bool | | None |
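The real helper performs its unpadding on torch tensors, but the bookkeeping can be sketched in plain Python: drop the padded positions in each batch row according to the boolean mask, concatenate what remains, and record the cumulative offsets. Names here are illustrative, not the module's API:

```python
def unpad(batch, padding_mask):
    """Pack the unpadded tokens of a [batch, seqlen] structure.

    Returns the packed tokens plus the cu_seqlens/max_seqlen metadata
    that varlen flash attention kernels expect alongside them.
    """
    packed, cu_seqlens = [], [0]
    for row, mask in zip(batch, padding_mask):
        kept = [tok for tok, keep in zip(row, mask) if keep]
        packed.extend(kept)
        cu_seqlens.append(cu_seqlens[-1] + len(kept))
    max_seqlen = max(b - a for a, b in zip(cu_seqlens, cu_seqlens[1:]))
    return packed, cu_seqlens, max_seqlen

tokens = [["a", "b", "c", "<pad>"], ["d", "<pad>", "<pad>", "<pad>"]]
mask = [[True, True, True, False], [True, False, False, False]]
packed, cu_seqlens, max_seqlen = unpad(tokens, mask)
print(packed)      # ['a', 'b', 'c', 'd']
print(cu_seqlens)  # [0, 3, 4]
print(max_seqlen)  # 3
```

With kvpacked or qkvpacked set, the helper additionally stacks the unpadded k/v (or q/k/v) along a new packing dimension so the fused kernels can consume them as one tensor.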