flash attention 2
This commit is contained in:
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|||||||
|
|
||||||
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||||
cd flash-attention && \
|
cd flash-attention && \
|
||||||
git checkout v1.0.9 && \
|
git checkout v2.0.0 && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
cd csrc/fused_dense_lib && \
|
cd csrc/fused_dense_lib && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from flash_attn.bert_padding import pad_input, unpad_input
|
from flash_attn.bert_padding import pad_input, unpad_input
|
||||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
|
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
|
||||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||||
|
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ def forward(
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=qkv.device,
|
device=qkv.device,
|
||||||
)
|
)
|
||||||
output = flash_attn_unpadded_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
@@ -95,7 +95,7 @@ def forward(
|
|||||||
three=3,
|
three=3,
|
||||||
h=nheads,
|
h=nheads,
|
||||||
)
|
)
|
||||||
output_unpad = flash_attn_unpadded_qkvpacked_func(
|
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||||
x_unpad,
|
x_unpad,
|
||||||
cu_q_lens,
|
cu_q_lens,
|
||||||
max_s,
|
max_s,
|
||||||
|
|||||||
Reference in New Issue
Block a user