Merge pull request #300 from OpenAccess-AI-Collective/pytorch-201
PyTorch 2.0.1
This commit is contained in:
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -18,12 +18,12 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: "117"
|
- cuda: "117"
|
||||||
cuda_version: 11.7.1
|
cuda_version: 11.7.1
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras: gptq
|
axolotl_extras: gptq
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
12
.github/workflows/main.yml
vendored
12
.github/workflows/main.yml
vendored
@@ -17,17 +17,17 @@ jobs:
|
|||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras: gptq
|
axolotl_extras: gptq
|
||||||
- cuda: cu117
|
- cuda: cu117
|
||||||
cuda_version: 11.7.1
|
cuda_version: 11.7.1
|
||||||
@@ -72,17 +72,17 @@ jobs:
|
|||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: cu118
|
- cuda: cu118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.9"
|
||||||
pytorch: 2.0.0
|
pytorch: 2.0.1
|
||||||
axolotl_extras: gptq
|
axolotl_extras: gptq
|
||||||
- cuda: cu117
|
- cuda: cu117
|
||||||
cuda_version: 11.7.1
|
cuda_version: 11.7.1
|
||||||
|
|||||||
@@ -38,8 +38,9 @@ WORKDIR /workspace
|
|||||||
|
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||||
|
|
||||||
RUN git clone https://github.com/HazyResearch/flash-attention.git && \
|
RUN git clone https://github.com/Dao-AILab/flash-attention.git && \
|
||||||
cd flash-attention && \
|
cd flash-attention && \
|
||||||
|
git checkout v1.0.9 && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
cd csrc/fused_dense_lib && \
|
cd csrc/fused_dense_lib && \
|
||||||
python3 setup.py bdist_wheel && \
|
python3 setup.py bdist_wheel && \
|
||||||
|
|||||||
@@ -184,14 +184,15 @@ def sdp_attention_forward(
|
|||||||
|
|
||||||
# We only apply sdp attention if we don't need to output the whole attention matrix
|
# We only apply sdp attention if we don't need to output the whole attention matrix
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
with torch.backends.cuda.sdp_kernel():
|
||||||
query_states,
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||||
key_states,
|
query_states,
|
||||||
value_states,
|
key_states,
|
||||||
attn_mask=attention_mask,
|
value_states,
|
||||||
is_causal=False,
|
attn_mask=attention_mask,
|
||||||
)
|
is_causal=False,
|
||||||
attn_weights = None
|
)
|
||||||
|
attn_weights = None
|
||||||
else:
|
else:
|
||||||
attn_weights = torch.matmul(
|
attn_weights = torch.matmul(
|
||||||
query_states, key_states.transpose(2, 3)
|
query_states, key_states.transpose(2, 3)
|
||||||
|
|||||||
Reference in New Issue
Block a user