diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index fc8d854d4..052f9aa72 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -31,11 +31,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.9.1
           axolotl_extras:
-#          - cuda: 130
-#            cuda_version: 13.0.0
-#            python_version: "3.11"
-#            pytorch: 2.9.1
-#            axolotl_extras:
+          - cuda: 130
+            cuda_version: 13.0.0
+            python_version: "3.11"
+            pytorch: 2.9.1
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -98,11 +98,11 @@ jobs:
           python_version: "3.11"
           pytorch: 2.9.1
           axolotl_extras:
-#          - cuda: 130
-#            cuda_version: 13.0.0
-#            python_version: "3.11"
-#            pytorch: 2.9.1
-#            axolotl_extras:
+          - cuda: 130
+            cuda_version: 13.0.0
+            python_version: "3.11"
+            pytorch: 2.9.1
+            axolotl_extras:
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index 950e69285..e080758c2 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -7,9 +7,9 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
 
 ENV PATH="/root/miniconda3/bin:${PATH}"
 
-ARG PYTHON_VERSION="3.10"
+ARG PYTHON_VERSION="3.11"
 ARG PYTORCH_VERSION="2.1.2"
-ARG CUDA="118"
+ARG CUDA="128"
 ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
 
 ENV PYTHON_VERSION=$PYTHON_VERSION
@@ -51,8 +51,16 @@ RUN git lfs install --skip-repo && \
     pip3 install -U --no-cache-dir pydantic==1.10.10 && \
     pip3 cache purge
 
-RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
-        wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-        pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-        rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-    fi
+RUN case "$PYTORCH_VERSION" in \
+    2.9.[0-9]*) \
+        if [ "$CUDA" = "128" ]; then \
+            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+            pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+            rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+        elif [ "$CUDA" = "130" ]; then \
+            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+            pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+            rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+        fi \
+        ;; \
+    esac
diff --git a/docker/Dockerfile-uv-base b/docker/Dockerfile-uv-base
index 2ca272c6e..0b4dfc33f 100644
--- a/docker/Dockerfile-uv-base
+++ b/docker/Dockerfile-uv-base
@@ -35,8 +35,16 @@ RUN uv pip install packaging setuptools wheel psutil \
     && uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
     && uv pip install awscli pydantic
 
-RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
-        wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-        uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-        rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
-    fi
+RUN case "$PYTORCH_VERSION" in \
+    2.9.[0-9]*) \
+        if [ "$CUDA" = "128" ]; then \
+            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+            uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+            rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
+        elif [ "$CUDA" = "130" ]; then \
+            wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+            uv pip install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+            rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
+        fi \
+        ;; \
+    esac
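
A quick way to confirm the pinned wheel matches the image it lands in is to import it against the installed torch build. This is a sketch, not part of the change itself: the image tag is a placeholder, and the asserted versions are simply read off the wheel filenames above (flash_attn 2.8.3 built for torch 2.9 on cu128/cu130).

# Hypothetical post-build check; <image> stands for whichever tag the
# workflow above pushed, run on a host with the NVIDIA container toolkit.
docker run --rm --gpus all <image> python3 - <<'EOF'
import torch
import flash_attn

# torch.version.cuda reports the CUDA toolkit torch was built against
# (expected 12.8 or 13.0 for the two matrix entries).
print("torch", torch.__version__, "cuda", torch.version.cuda)
print("flash_attn", flash_attn.__version__)

assert torch.__version__.startswith("2.9"), "wheel was built for torch 2.9"
assert flash_attn.__version__.startswith("2.8.3"), "expected the prebuilt 2.8.3 wheel"
EOF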