diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index f3ad69570..491b51efe 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -18,12 +18,12 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: "118" cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: "117" cuda_version: 11.7.1 @@ -33,7 +33,7 @@ jobs: - cuda: "118" cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: gptq steps: - name: Checkout diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 07f25cac6..0ab14eb58 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -17,17 +17,17 @@ jobs: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: gptq - cuda: cu117 cuda_version: 11.7.1 @@ -72,17 +72,17 @@ jobs: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.10" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: - cuda: cu118 cuda_version: 11.8.0 python_version: "3.9" - pytorch: 2.0.0 + pytorch: 2.0.1 axolotl_extras: gptq - cuda: cu117 cuda_version: 11.7.1 diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index adf7996ee..ff60ca504 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -38,8 +38,9 @@ WORKDIR /workspace ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" -RUN git clone https://github.com/HazyResearch/flash-attention.git && \ +RUN git clone https://github.com/Dao-AILab/flash-attention.git && \ cd flash-attention && \ + git checkout v1.0.9 && \ python3 setup.py bdist_wheel && \ cd csrc/fused_dense_lib && \ python3 setup.py bdist_wheel && \ diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py index c6bdafb89..8fa00f43b 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py @@ -184,14 +184,15 @@ def sdp_attention_forward( # We only apply sdp attention if we don't need to output the whole attention matrix if not output_attentions: - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - is_causal=False, - ) - attn_weights = None + with torch.backends.cuda.sdp_kernel(): + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=attention_mask, + is_causal=False, + ) + attn_weights = None else: attn_weights = torch.matmul( query_states, key_states.transpose(2, 3)