extras for the various flash attn subdirs and build those in the base module as it is a slow step
This commit is contained in:
@@ -29,7 +29,8 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
|
python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy'
|
||||||
|
|
||||||
RUN git lfs install --skip-repo && \
|
RUN git lfs install --skip-repo && \
|
||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
|
|||||||
8
setup.py
8
setup.py
@@ -53,7 +53,13 @@ setup(
|
|||||||
"flash-attn==2.3.3",
|
"flash-attn==2.3.3",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||||
|
],
|
||||||
|
"dropout-layer-norm": [
|
||||||
|
"dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
|
||||||
|
],
|
||||||
|
"xentropy-cuda-lib": [
|
||||||
|
"xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed>=0.13.1",
|
"deepspeed>=0.13.1",
|
||||||
|
|||||||
@@ -113,7 +113,10 @@ def patch_cross_entropy():
|
|||||||
|
|
||||||
def patch_rms_norm():
|
def patch_rms_norm():
|
||||||
try:
|
try:
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
try:
|
||||||
|
from flash_attn.ops.triton.rms_norm import RMSNorm
|
||||||
|
except ImportError:
|
||||||
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
class LlamaRMSNorm(RMSNorm):
|
||||||
"""Patched LLamaRMSNorm"""
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|||||||
Reference in New Issue
Block a user