Merge pull request #299 from OpenAccess-AI-Collective/flash-attention-2

Flash attention 2
Wing Lian
2023-07-22 04:07:48 -04:00
committed by GitHub
2 changed files with 4 additions and 4 deletions


@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \
-git checkout v1.0.9 && \
+git checkout 9ee0ff1 && \
python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \
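
The checkout moves the Docker build from the v1.0.9 tag to commit 9ee0ff1, which corresponds to the FlashAttention 2 line this PR targets. A minimal sketch (assuming the wheels built above are installed in the image) to confirm the runtime picks up the 2.x API rather than the old one; the exact version string printed will depend on the pinned commit:

import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

# FlashAttention 2.x exposes the *_varlen_* name; 1.x only shipped *_unpadded_*.
print(flash_attn.__version__)
print(flash_attn_varlen_qkvpacked_func.__module__)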


@@ -8,7 +8,7 @@ import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
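
FlashAttention 2 renamed flash_attn_unpadded_qkvpacked_func to flash_attn_varlen_qkvpacked_func; the call sites below keep identical arguments, so only the import and the two function names change. A hedged sketch of a version-tolerant import, not part of this commit, in case both flash-attn 1.x and 2.x need to be supported (the alias name is illustrative):

# Prefer the 2.x name, fall back to the 1.x name.
try:
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_qkvpacked_func as flash_attn_qkvpacked_func,
    )
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_qkvpacked_func,  # flash-attn 1.x
    )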
@@ -79,7 +79,7 @@ def forward(
dtype=torch.int32,
device=qkv.device,
)
-output = flash_attn_unpadded_qkvpacked_func(
+output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
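
For reference, a minimal sketch of the tensor shapes the varlen packed-QKV call above expects, inferred from this call site; the sizes are illustrative, and a CUDA device with fp16/bf16 tensors is assumed:

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

nheads, headdim = 8, 64
# Two sequences of lengths 2 and 4 packed into one (total_tokens, 3, nheads, headdim) tensor.
qkv = torch.randn(6, 3, nheads, headdim, dtype=torch.float16, device="cuda")
cu_q_lens = torch.tensor([0, 2, 6], dtype=torch.int32, device="cuda")  # cumulative sequence lengths
max_s = 4  # longest sequence in the batch

out = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True)
# out has shape (total_tokens, nheads, headdim); the surrounding code rearranges it back to (b, s, ...).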
@@ -95,7 +95,7 @@ def forward(
three=3,
h=nheads,
)
-output_unpad = flash_attn_unpadded_qkvpacked_func(
+output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
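
This second call site runs on tokens produced by unpad_input and is later scattered back with pad_input, both imported above. A rough sketch of that round-trip, assuming the 4-tuple return of flash_attn.bert_padding.unpad_input from this era; shapes and names are illustrative:

import torch
from flash_attn.bert_padding import pad_input, unpad_input

bsz, seqlen, hidden = 2, 4, 16
hidden_states = torch.randn(bsz, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]], dtype=torch.bool)

# Drop padded positions; cu_q_lens / max_s are what feed the varlen kernel above.
x_unpad, indices, cu_q_lens, max_s = unpad_input(hidden_states, attention_mask)
# ... reshape x_unpad into packed QKV and call flash_attn_varlen_qkvpacked_func ...
# Scatter the unpadded output back to the original (bsz, seqlen, hidden) layout.
restored = pad_input(x_unpad, indices, bsz, seqlen)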