upgrade to flash-attn 2.8.0.post2 (#2828)
* upgrade to flash-attn 2.8.0.post2 * use cu126 with torch 2.6 * seems vllm 0.8.5.post1 not compatible with cuda12.6.3 and torch 2.6 * cu126 + torch 2.6 as the default * use cu126 for multigpu w torch 2.6 too * drop vllm for now from ci for now
This commit is contained in:
13
.github/workflows/main.yml
vendored
13
.github/workflows/main.yml
vendored
@@ -20,12 +20,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -88,8 +87,8 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
@@ -146,8 +145,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
6
.github/workflows/multi-gpu-e2e.yml
vendored
6
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,11 +26,11 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
|
|||||||
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@@ -195,12 +195,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -247,8 +247,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
@@ -311,7 +311,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -37,7 +37,3 @@ RUN git lfs install --skip-repo && \
|
|||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
|
||||||
pip3 install flash-attn==2.7.4.post1; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
|
||||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
|
||||||
fi
|
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -111,9 +111,9 @@ def get_package_version():
|
|||||||
|
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
"flash-attn": ["flash-attn==2.8.0.post2"],
|
||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.7.4.post1",
|
"flash-attn==2.8.0.post2",
|
||||||
"ring-flash-attn>=0.1.4",
|
"ring-flash-attn>=0.1.4",
|
||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user