diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index ff60ca504..6862569b9 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -40,7 +40,7 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
 RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
-    git checkout v1.0.9 && \
+    git checkout 9ee0ff1 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
diff --git a/src/axolotl/flash_attn.py b/src/axolotl/flash_attn.py
index 406dd15ad..073786882 100644
--- a/src/axolotl/flash_attn.py
+++ b/src/axolotl/flash_attn.py
@@ -8,7 +8,7 @@ import torch
 import transformers
 from einops import rearrange
 from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
+from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
 
@@ -79,7 +79,7 @@ def forward(
             dtype=torch.int32,
             device=qkv.device,
         )
-        output = flash_attn_unpadded_qkvpacked_func(
+        output = flash_attn_varlen_qkvpacked_func(
             qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
@@ -95,7 +95,7 @@ def forward(
             three=3,
             h=nheads,
         )
-        output_unpad = flash_attn_unpadded_qkvpacked_func(
+        output_unpad = flash_attn_varlen_qkvpacked_func(
             x_unpad,
             cu_q_lens,
             max_s,
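
For context, the rename in this diff tracks flash-attention's switch from the `flash_attn_unpadded_*` names to `flash_attn_varlen_*`. If an environment cannot be guaranteed to use the pinned commit, a guarded import is one way to stay compatible with both naming schemes. This is a minimal sketch and an assumption about how one might hedge the upgrade, not part of the change above:

```python
# Sketch only (not part of this PR): prefer the new varlen name introduced by
# newer flash-attention commits, and fall back to the old unpadded name that
# releases such as v1.0.9 still expose.
try:
    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
except ImportError:
    from flash_attn.flash_attn_interface import (
        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
    )
```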