Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201
Pytorch 2.0.1
This commit is contained in:
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -18,12 +18,12 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: "117"
             cuda_version: 11.7.1
@@ -33,7 +33,7 @@ jobs:
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
     steps:
       - name: Checkout
12
.github/workflows/main.yml
vendored
12
.github/workflows/main.yml
vendored
@@ -17,17 +17,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
@@ -72,17 +72,17 @@ jobs:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras:
           - cuda: cu118
             cuda_version: 11.8.0
             python_version: "3.9"
-            pytorch: 2.0.0
+            pytorch: 2.0.1
             axolotl_extras: gptq
           - cuda: cu117
             cuda_version: 11.7.1
@@ -38,8 +38,9 @@ WORKDIR /workspace
 
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 
-RUN git clone https://github.com/HazyResearch/flash-attention.git && \
+RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
     cd flash-attention && \
+    git checkout v1.0.9 && \
     python3 setup.py bdist_wheel && \
     cd csrc/fused_dense_lib && \
     python3 setup.py bdist_wheel && \
@@ -184,14 +184,15 @@ def sdp_attention_forward(
 
     # We only apply sdp attention if we don't need to output the whole attention matrix
     if not output_attentions:
-        attn_output = torch.nn.functional.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_mask=attention_mask,
-            is_causal=False,
-        )
-        attn_weights = None
+        with torch.backends.cuda.sdp_kernel():
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_mask=attention_mask,
+                is_causal=False,
+            )
+            attn_weights = None
     else:
         attn_weights = torch.matmul(
             query_states, key_states.transpose(2, 3)
Reference in New Issue
Block a user