From 9b0be4f15cde91a509bf4ecd4c61e50fd6d519d5 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 16 May 2025 14:54:24 -0400
Subject: [PATCH] fix 12.8 image and add flash-attn v3 hopper base image

---
 .github/workflows/base.yml | 9 ++++++++-
 docker/Dockerfile-base     | 7 ++++++-
 2 files changed, 14 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index 9e19114d7..5e292ed42 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -47,11 +47,18 @@ jobs:
             pytorch: 2.7.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "128"
-            cuda_version: 12.6.3
+            cuda_version: 12.8.1
             cudnn_version: ""
             python_version: "3.11"
             pytorch: 2.7.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          - cuda: "128"
+            cuda_version: 12.8.1
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.7.0
+            suffix: "-hopper"
+            torch_cuda_arch_list: "9.0+PTX"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index cf1af9682..ccb87456a 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -38,6 +38,11 @@ RUN git lfs install --skip-repo && \
     # The base image ships with `pydantic==1.8.2` which is not working
     pip3 install -U --no-cache-dir pydantic==1.10.10
 
-RUN if [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
+RUN if [ "$TORCH_CUDA_ARCH_LIST" = "9.0+PTX" ] ; then \
+    git clone https://github.com/Dao-AILab/flash-attention.git; \
+    cd flash-attention; git checkout v2.7.4.post1; \
+    cd hopper; \
+    FLASH_ATTENTION_DISABLE_SM80=TRUE FLASH_ATTENTION_DISABLE_FP8=TRUE MAX_JOBS=128 python setup.py install; \
+elif [ "$PYTORCH_VERSION" = "2.7.0" ] ; then \
     pip3 install flash-attn==2.7.4.post1; \
 fi