Compare commits
2 Commits
flash-attn
...
sdpa-multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a538be9c2 | ||
|
|
74c72ca5eb |
21
.github/workflows/base.yml
vendored
21
.github/workflows/base.yml
vendored
@@ -1,10 +1,7 @@
|
||||
name: ci-cd-base
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- "main-base"
|
||||
- "dev-base"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-base:
|
||||
@@ -15,11 +12,6 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
@@ -28,12 +20,17 @@ jobs:
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
pytorch: 2.1.2
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "121"
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
pytorch: 2.1.2
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "121"
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -56,7 +53,7 @@ jobs:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-args: |
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
|
||||
33
.github/workflows/main.yml
vendored
33
.github/workflows/main.yml
vendored
@@ -4,6 +4,7 @@ on:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build-axolotl:
|
||||
@@ -15,24 +16,24 @@ jobs:
|
||||
include:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
@@ -86,24 +87,24 @@ jobs:
|
||||
include:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.9"
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.1
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
- cuda: 121
|
||||
cuda_version: 12.1.0
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
|
||||
@@ -29,8 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy'
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
|
||||
8
setup.py
8
setup.py
@@ -53,13 +53,7 @@ setup(
|
||||
"flash-attn==2.3.3",
|
||||
],
|
||||
"fused-dense-lib": [
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"dropout-layer-norm": [
|
||||
"dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
|
||||
],
|
||||
"xentropy-cuda-lib": [
|
||||
"xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
|
||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed>=0.13.1",
|
||||
|
||||
@@ -90,46 +90,37 @@ def replace_llama_attn_with_flash_attn(
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
patch_cross_entropy()
|
||||
|
||||
if rms_norm:
|
||||
patch_rms_norm()
|
||||
|
||||
|
||||
def patch_cross_entropy():
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
|
||||
def patch_rms_norm():
|
||||
try:
|
||||
try:
|
||||
from flash_attn.ops.triton.rms_norm import RMSNorm
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
|
||||
@@ -24,12 +24,6 @@ from transformers import ( # noqa: F401
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_cross_entropy as llama_patch_cross_entropy,
|
||||
)
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
patch_rms_norm as llama_patch_rms_norm,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
@@ -287,7 +281,15 @@ def load_model(
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
if cfg.s2_attention:
|
||||
if cfg.sample_packing:
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=True,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
)
|
||||
elif cfg.s2_attention:
|
||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=False,
|
||||
@@ -295,21 +297,6 @@ def load_model(
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
use_shifted_sparse_attn=True,
|
||||
)
|
||||
elif cfg.device not in ["mps", "cpu"] and not inference:
|
||||
if cfg.sample_packing:
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=True,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attn_cross_entropy:
|
||||
llama_patch_cross_entropy()
|
||||
|
||||
if cfg.flash_attn_rms_norm:
|
||||
llama_patch_rms_norm()
|
||||
|
||||
elif cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
|
||||
@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
def test_output_multipack(self):
|
||||
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
|
||||
dtype = torch.float32
|
||||
expected_output = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user