From c58034d48c089407b1353cc3a2b14bb5b59f8d14 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 20 Jul 2023 00:47:13 -0400
Subject: [PATCH 1/3] use pytorch 2.0.1

---
 .github/workflows/base.yml |  6 +++---
 .github/workflows/main.yml | 12 ++++++------
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index f3ad69570..491b51efe 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -18,12 +18,12 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "117"
             cuda_version: 11.7.1
@@ -33,7 +33,7 @@
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
     steps:
       - name: Checkout
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 07f25cac6..0ab14eb58 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -17,17 +17,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
          - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
          - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
          - cuda: cu117
             cuda_version: 11.7.1
@@ -72,17 +72,17 @@
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
          - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
          - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
          - cuda: cu117
             cuda_version: 11.7.1

From b06d3e364554f84be3435b8f41084ca88303a30a Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 20 Jul 2023 01:02:08 -0400
Subject: [PATCH 2/3] explicitly pin flash attention 1 to v1.0.9

---
 docker/Dockerfile-base | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index adf7996ee..ff60ca504 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -38,8 +38,9 @@ WORKDIR /workspace
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \

From a032c9f452c7aa2c65078170fec2496a72222f03 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 20 Jul 2023 01:05:48 -0400
Subject: [PATCH 3/3] fix sdp attention to use the flash/mem-efficient context manager

---
 .../monkeypatch/llama_attn_hijack_xformers.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
index c6bdafb89..8fa00f43b 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_xformers.py
@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
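
Note (illustrative, not part of the patch series): PATCH 3/3 relies on
torch.backends.cuda.sdp_kernel(), the PyTorch 2.0.x context manager that
gates which scaled_dot_product_attention backend is allowed to run. Below is
a minimal standalone sketch of that mechanism, assuming PyTorch 2.0.1 on a
CUDA device; the tensor shapes and the particular backend flags are
illustrative choices, not taken from the patch.

    # Sketch: restrict SDPA to the flash and memory-efficient kernels,
    # disabling the math fallback. Assumes PyTorch 2.0.1 with CUDA.
    import torch

    # (batch, heads, seq_len, head_dim) -- illustrative shapes only
    q = torch.randn(1, 8, 128, 64, dtype=torch.float16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Inside this context, SDPA may only dispatch to the flash or
    # memory-efficient kernel; it errors if neither is eligible.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True, enable_mem_efficient=True, enable_math=False
    ):
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True
        )

Called with no arguments, as in the hunk above, sdp_kernel() leaves all three
backends (flash, memory-efficient, math) enabled and lets PyTorch choose.
Because the patched code also passes attn_mask, PyTorch 2.0.x will dispatch
that call to the memory-efficient or math kernel, as the flash kernel in that
release does not accept an arbitrary attention mask.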