From 9b790d359b2a901bc731a7b4bfbbab73fca16235 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 20 Jul 2023 00:00:49 -0400
Subject: [PATCH 1/2] flash attention 2

---
 docker/Dockerfile-base    | 2 +-
 src/axolotl/flash_attn.py | 6 +++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index ff60ca504..be67e8eb4 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout v1.0.9 && \
+    git checkout v2.0.0 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
index 406dd15ad..073786882 100644
--- a/src/axolotl/flash_attn.py
+++ b/src/axolotl/flash_attn.py
@@ -8,7 +8,7 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
 
@@ -79,7 +79,7 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -95,7 +95,7 @@ def forward(
             three=3,
             h=nheads,
         )
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad,
             cu_q_lens,
             max_s,

From cdf85fdbd5679c17cd28772445680e4b44babefd Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 21 Jul 2023 08:18:53 -0400
Subject: [PATCH 2/2] pin flash attention 2 to the fix for backwards pass

---
 docker/Dockerfile-base | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index be67e8eb4..6862569b9 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout v2.0.0 && \
+    git checkout 9ee0ff1 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
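
Note: the only Python-level change in this series is the rename of the flash-attn
entry point; flash-attn 2.x calls the variable-length packed-QKV kernel
flash_attn_varlen_qkvpacked_func instead of flash_attn_unpadded_qkvpacked_func.
Below is a minimal standalone sketch of the call as it appears after the patch,
assuming flash-attn >= 2.0 is installed with CUDA support. The tensor sizes are
illustrative placeholders, not values from axolotl.

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

bsz, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes

# Packed, unpadded QKV: (total_tokens, 3, nheads, headdim), fp16 on GPU
qkv = torch.randn(
    bsz * seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda"
)

# Cumulative sequence lengths, shape (bsz + 1,), int32, as built in flash_attn.py
cu_q_lens = torch.arange(
    0, (bsz + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda"
)

# Positional arguments match the v1 "unpadded" function
# (qkv, cu_seqlens, max_seqlen, dropout_p, ...), so the rename is a drop-in swap
output = flash_attn_varlen_qkvpacked_func(
    qkv, cu_q_lens, seqlen, 0.0, softmax_scale=None, causal=True
)
# output: (total_tokens, nheads, headdim)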