Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201

PyTorch 2.0.1
Authored by Wing Lian on 2023-07-21 00:28:38 -04:00; committed by GitHub.
4 changed files with 20 additions and 18 deletions

View File

@@ -18,12 +18,12 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "117"
             cuda_version: 11.7.1
@@ -33,7 +33,7 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
     steps:
       - name: Checkout
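
The hunks in this file and in the next workflow file are the same version bump: every cu118 matrix entry that pinned pytorch: 2.0.0 now pins 2.0.1, while the CUDA 11.7 entries are left alone. A quick way to confirm a built image actually matches its matrix entry is to inspect torch's build metadata at runtime; a minimal sketch (the expected strings follow from the cu118/2.0.1 entries above, not from captured build output):

    import torch

    # For a cu118 / PyTorch 2.0.1 matrix entry we would expect roughly:
    print(torch.__version__)   # e.g. "2.0.1+cu118"
    print(torch.version.cuda)  # e.g. "11.8"
    assert torch.__version__.startswith("2.0.1"), "image built against wrong PyTorch"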

View File

@@ -17,17 +17,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1

View File

@@ -38,8 +38,9 @@ WORKDIR /workspace
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
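
Two things change here: the flash-attention clone moves from the old HazyResearch namespace to its current home under Dao-AILab, and the build is pinned to the v1.0.9 tag rather than whatever HEAD happens to be, which keeps image builds reproducible. A minimal import check, assuming the wheels built above get installed later in the Dockerfile:

    # Assumes the flash-attention wheels produced above were installed into the image.
    import flash_attn
    print(flash_attn.__version__)  # expect "1.0.9" given the pinned tag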

View File

@@ -184,14 +184,15 @@ def sdp_attention_forward(
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+        attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
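
The net change in this file wraps the existing call in torch.backends.cuda.sdp_kernel(), the PyTorch 2.0.x context manager that gates which fused scaled-dot-product-attention backends (flash, memory-efficient, math) may be dispatched; invoked with no arguments, as here, it leaves all of them enabled. A self-contained sketch of the same pattern outside the model code (shapes are illustrative):

    import torch

    # (batch, num_heads, seq_len, head_dim): the layout
    # scaled_dot_product_attention expects, matching query_states et al. above.
    q = torch.randn(1, 8, 16, 64)
    k = torch.randn(1, 8, 16, 64)
    v = torch.randn(1, 8, 16, 64)

    # sdp_kernel(enable_flash=..., enable_math=..., enable_mem_efficient=...) selects
    # the eligible fused backends; with no arguments all of them stay enabled.
    with torch.backends.cuda.sdp_kernel():
        out = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, is_causal=False
        )

    print(out.shape)  # torch.Size([1, 8, 16, 64])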