use updated version of prebuilt wheels for flash attention for cu130 (#3342)
* use updated version of prebuilt wheels for flash attention for cu130 * use elif * fix the uv base installs of FA also * make wget less verbose
This commit is contained in:
20
.github/workflows/main.yml
vendored
20
.github/workflows/main.yml
vendored
@@ -31,11 +31,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
# - cuda: 130
|
- cuda: 130
|
||||||
# cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
# python_version: "3.11"
|
python_version: "3.11"
|
||||||
# pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
# axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -98,11 +98,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
# - cuda: 130
|
- cuda: 130
|
||||||
# cuda_version: 13.0.0
|
cuda_version: 13.0.0
|
||||||
# python_version: "3.11"
|
python_version: "3.11"
|
||||||
# pytorch: 2.9.1
|
pytorch: 2.9.1
|
||||||
# axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
@@ -7,9 +7,9 @@ FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION A
|
|||||||
|
|
||||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
ARG PYTHON_VERSION="3.10"
|
ARG PYTHON_VERSION="3.11"
|
||||||
ARG PYTORCH_VERSION="2.1.2"
|
ARG PYTORCH_VERSION="2.1.2"
|
||||||
ARG CUDA="118"
|
ARG CUDA="128"
|
||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
@@ -51,8 +51,16 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
pip3 install -U --no-cache-dir pydantic==1.10.10 && \
|
||||||
pip3 cache purge
|
pip3 cache purge
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" =~ ^2\.9\.[0-9]+$ ] && [ "$CUDA" = "128" ] ; then \
|
RUN case "$PYTORCH_VERSION" in \
|
||||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
2.9.[0-9]*) \
|
||||||
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
if [ "$CUDA" = "128" ]; then \
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
fi
|
pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
elif [ "$CUDA" = "130" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
fi \
|
||||||
|
;; \
|
||||||
|
esac
|
||||||
|
|||||||
@@ -35,8 +35,16 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.9.0" ] && [ "$CUDA" = "128" ] ; then \
|
RUN case "$PYTORCH_VERSION" in \
|
||||||
wget https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.17/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
2.9.[0-9]*) \
|
||||||
uv pip install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
if [ "$CUDA" = "128" ]; then \
|
||||||
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
fi
|
uv pip3 install --no-cache-dir flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu128torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
elif [ "$CUDA" = "130" ]; then \
|
||||||
|
wget -nv https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.5.4/flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
uv pip3 install --no-cache-dir flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
rm flash_attn-2.8.3+cu130torch2.9-cp311-cp311-linux_x86_64.whl; \
|
||||||
|
fi \
|
||||||
|
;; \
|
||||||
|
esac
|
||||||
|
|||||||
Reference in New Issue
Block a user