From 34de5b3bd56da51c02872de86bed1623a8d891a5 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 26 Jan 2024 00:40:39 -0500 Subject: [PATCH] extras for the various flash attn subdirs and build those in the base module as it is a slow step --- docker/Dockerfile-base | 3 ++- setup.py | 8 +++++++- src/axolotl/monkeypatch/llama_attn_hijack_flash.py | 5 ++++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index e8f3c7f4e..8b264b489 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -29,7 +29,8 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}" WORKDIR /workspace RUN python3 -m pip install --upgrade pip && pip3 install packaging && \ - python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA + python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \ + python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/xentropy' RUN git lfs install --skip-repo && \ pip3 install awscli && \ diff --git a/setup.py b/setup.py index b0c9ecbc3..33dd63bb6 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,13 @@ setup( "flash-attn==2.3.3", ], "fused-dense-lib": [ - "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", + "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib", + ], + "dropout-layer-norm": [ + "dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm", + ], + 
"xentropy-cuda-lib": [ + "xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/xentropy", ], "deepspeed": [ "deepspeed>=0.13.1", diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py index 8d409527e..97a392783 100644 --- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py +++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py @@ -113,7 +113,10 @@ def patch_cross_entropy(): def patch_rms_norm(): try: - from flash_attn.ops.rms_norm import RMSNorm + try: + from flash_attn.ops.triton.rms_norm import RMSNorm + except ImportError: + from flash_attn.ops.rms_norm import RMSNorm class LlamaRMSNorm(RMSNorm): """Patched LLamaRMSNorm"""