diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base
index e8f3c7f4e..8b264b489 100644
--- a/docker/Dockerfile-base
+++ b/docker/Dockerfile-base
@@ -29,7 +29,8 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
 WORKDIR /workspace
 
 RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
-    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
+    python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
+    python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy'
 
 RUN git lfs install --skip-repo && \
     pip3 install awscli && \
diff --git a/setup.py b/setup.py
index b0c9ecbc3..33dd63bb6 100644
--- a/setup.py
+++ b/setup.py
@@ -53,7 +53,13 @@ setup(
             "flash-attn==2.3.3",
         ],
         "fused-dense-lib": [
-            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
+            "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
+        ],
+        "dropout-layer-norm": [
+            "dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
+        ],
+        "xentropy-cuda-lib": [
+            "xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
         ],
         "deepspeed": [
             "deepspeed>=0.13.1",
diff --git a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
index 8d409527e..97a392783 100644
--- a/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
+++ b/src/axolotl/monkeypatch/llama_attn_hijack_flash.py
@@ -113,7 +113,10 @@ def patch_cross_entropy():
 
 def patch_rms_norm():
     try:
-        from flash_attn.ops.rms_norm import RMSNorm
+        try:
+            from flash_attn.ops.triton.rms_norm import RMSNorm
+        except ImportError:
+            from flash_attn.ops.rms_norm import RMSNorm
 
         class LlamaRMSNorm(RMSNorm):
             """Patched LLamaRMSNorm"""
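
Note on the llama_attn_hijack_flash.py change: the nested import now prefers the Triton RMSNorm kernel shipped with flash-attn and falls back to the CUDA extension built from csrc/layer_norm, which is what the new dropout-layer-norm extra (and the matching Dockerfile install line) provides. A minimal standalone sketch of that fallback order is below; the plain-PyTorch class at the end is an illustration-only stand-in, not part of this patch, added only so the snippet also runs on a machine without flash-attn installed.

    # Sketch of the import preference used by patch_rms_norm(); the torch-only
    # fallback class is hypothetical and exists only to keep this example self-contained.
    import torch

    try:
        from flash_attn.ops.triton.rms_norm import RMSNorm  # Triton kernel, no extension build needed
    except ImportError:
        try:
            from flash_attn.ops.rms_norm import RMSNorm  # CUDA extension from csrc/layer_norm
        except ImportError:

            class RMSNorm(torch.nn.Module):
                """Plain-PyTorch RMSNorm stand-in (not part of the patch)."""

                def __init__(self, hidden_size, eps=1e-6):
                    super().__init__()
                    self.weight = torch.nn.Parameter(torch.ones(hidden_size))
                    self.eps = eps

                def forward(self, hidden_states):
                    variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
                    normed = hidden_states * torch.rsqrt(variance + self.eps)
                    return self.weight * normed.to(self.weight.dtype)

    if __name__ == "__main__":
        # Report which implementation was selected; forward passes for the
        # flash-attn variants require CUDA tensors, so none are run here.
        print(f"RMSNorm implementation: {RMSNorm.__module__}")

With the setup.py change, the same kernels can also be pulled in outside the Docker image through the new extras, e.g. something like pip3 install -e '.[flash-attn,dropout-layer-norm,xentropy-cuda-lib]'.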