Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201

PyTorch 2.0.1
This commit is contained in:
Wing Lian
2023-07-21 00:28:38 -04:00
committed by GitHub
4 changed files with 20 additions and 18 deletions

View File

@@ -18,12 +18,12 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: "117" - cuda: "117"
cuda_version: 11.7.1 cuda_version: 11.7.1
@@ -33,7 +33,7 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: gptq axolotl_extras: gptq
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -17,17 +17,17 @@ jobs:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117 - cuda: cu117
cuda_version: 11.7.1 cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: cu118 - cuda: cu118
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.9" python_version: "3.9"
pytorch: 2.0.0 pytorch: 2.0.1
axolotl_extras: gptq axolotl_extras: gptq
- cuda: cu117 - cuda: cu117
cuda_version: 11.7.1 cuda_version: 11.7.1

View File

@@ -38,8 +38,9 @@ WORKDIR /workspace
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN git clone https://github.com/HazyResearch/flash-attention.git && \ RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
cd flash-attention && \ cd flash-attention && \
git checkout v1.0.9 && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \
cd csrc/fused_dense_lib && \ cd csrc/fused_dense_lib && \
python3 setup.py bdist_wheel && \ python3 setup.py bdist_wheel && \

View File

@@ -184,14 +184,15 @@ def sdp_attention_forward(
# We only apply sdp attention if we don't need to output the whole attention matrix # We only apply sdp attention if we don't need to output the whole attention matrix
if not output_attentions: if not output_attentions:
attn_output = torch.nn.functional.scaled_dot_product_attention( with torch.backends.cuda.sdp_kernel():
query_states, attn_output = torch.nn.functional.scaled_dot_product_attention(
key_states, query_states,
value_states, key_states,
attn_mask=attention_mask, value_states,
is_causal=False, attn_mask=attention_mask,
) is_causal=False,
attn_weights = None )
attn_weights = None
else: else:
attn_weights = torch.matmul( attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3) query_states, key_states.transpose(2, 3)