Compare commits
2 Commits
34de5b3bd5
...
sdpa-multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a538be9c2 | ||
|
|
74c72ca5eb |
21
.github/workflows/base.yml
vendored
21
.github/workflows/base.yml
vendored
@@ -1,10 +1,7 @@
|
|||||||
name: ci-cd-base
|
name: ci-cd-base
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
workflow_dispatch:
|
||||||
branches:
|
|
||||||
- "main-base"
|
|
||||||
- "dev-base"
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-base:
|
build-base:
|
||||||
@@ -15,11 +12,6 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: "118"
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.9"
|
|
||||||
pytorch: 2.0.1
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
@@ -28,12 +20,17 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
- cuda: "121"
|
- cuda: "121"
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -56,7 +53,7 @@ jobs:
|
|||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ./docker/Dockerfile-base
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
build-args: |
|
build-args: |
|
||||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||||
|
|||||||
33
.github/workflows/main.yml
vendored
33
.github/workflows/main.yml
vendored
@@ -4,6 +4,7 @@ on:
|
|||||||
push:
|
push:
|
||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
build-axolotl:
|
build-axolotl:
|
||||||
@@ -15,24 +16,24 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
@@ -86,24 +87,24 @@ jobs:
|
|||||||
include:
|
include:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.9"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.1
|
|
||||||
axolotl_extras:
|
|
||||||
- cuda: 121
|
- cuda: 121
|
||||||
cuda_version: 12.1.0
|
cuda_version: 12.1.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.1
|
pytorch: 2.1.2
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.1.2
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
@@ -29,8 +29,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||||
python3 -m pip install flash-attn==2.3.3 'fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib' 'dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm' 'xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy'
|
|
||||||
|
|
||||||
RUN git lfs install --skip-repo && \
|
RUN git lfs install --skip-repo && \
|
||||||
pip3 install awscli && \
|
pip3 install awscli && \
|
||||||
|
|||||||
8
setup.py
8
setup.py
@@ -53,13 +53,7 @@ setup(
|
|||||||
"flash-attn==2.3.3",
|
"flash-attn==2.3.3",
|
||||||
],
|
],
|
||||||
"fused-dense-lib": [
|
"fused-dense-lib": [
|
||||||
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
"fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",
|
||||||
],
|
|
||||||
"dropout-layer-norm": [
|
|
||||||
"dropout-layer-norm @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/layer_norm",
|
|
||||||
],
|
|
||||||
"xentropy-cuda-lib": [
|
|
||||||
"xentropy_cuda_lib @ git+https://github.com/Dao-AILab/flash-attention.git@2.3.3#&subdirectory=csrc/xentropy",
|
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed>=0.13.1",
|
"deepspeed>=0.13.1",
|
||||||
|
|||||||
@@ -90,46 +90,37 @@ def replace_llama_attn_with_flash_attn(
|
|||||||
llama_model_forward
|
llama_model_forward
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# skip only if explicitly disabled
|
||||||
if cross_entropy:
|
if cross_entropy:
|
||||||
patch_cross_entropy()
|
|
||||||
|
|
||||||
if rms_norm:
|
|
||||||
patch_rms_norm()
|
|
||||||
|
|
||||||
|
|
||||||
def patch_cross_entropy():
|
|
||||||
try:
|
|
||||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
|
||||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
|
||||||
CrossEntropyLoss, inplace_backward=True
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
LOG.info(
|
|
||||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_rms_norm():
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.ops.triton.rms_norm import RMSNorm
|
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||||
|
|
||||||
|
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||||
|
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||||
|
CrossEntropyLoss, inplace_backward=True
|
||||||
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
LOG.info(
|
||||||
|
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# skip only if explicitly disabled
|
||||||
|
if rms_norm:
|
||||||
|
try:
|
||||||
from flash_attn.ops.rms_norm import RMSNorm
|
from flash_attn.ops.rms_norm import RMSNorm
|
||||||
|
|
||||||
class LlamaRMSNorm(RMSNorm):
|
class LlamaRMSNorm(RMSNorm):
|
||||||
"""Patched LLamaRMSNorm"""
|
"""Patched LLamaRMSNorm"""
|
||||||
|
|
||||||
def __init__(self, hidden_size, eps=1e-6):
|
def __init__(self, hidden_size, eps=1e-6):
|
||||||
super().__init__(hidden_size, eps=eps)
|
super().__init__(hidden_size, eps=eps)
|
||||||
|
|
||||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||||
except ImportError:
|
except ImportError:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class FusedAttention(LlamaAttention):
|
class FusedAttention(LlamaAttention):
|
||||||
|
|||||||
@@ -24,12 +24,6 @@ from transformers import ( # noqa: F401
|
|||||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
|
||||||
patch_cross_entropy as llama_patch_cross_entropy,
|
|
||||||
)
|
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
|
||||||
patch_rms_norm as llama_patch_rms_norm,
|
|
||||||
)
|
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.chat_templates import chat_templates
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
@@ -287,7 +281,15 @@ def load_model(
|
|||||||
replace_llama_attn_with_flash_attn,
|
replace_llama_attn_with_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.s2_attention:
|
if cfg.sample_packing:
|
||||||
|
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||||
|
LOG.info("patching with flash attention for sample packing")
|
||||||
|
replace_llama_attn_with_flash_attn(
|
||||||
|
packed=True,
|
||||||
|
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||||
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
|
)
|
||||||
|
elif cfg.s2_attention:
|
||||||
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
LOG.info("patching w/ flash-enabled, shifted-sparse attention")
|
||||||
replace_llama_attn_with_flash_attn(
|
replace_llama_attn_with_flash_attn(
|
||||||
packed=False,
|
packed=False,
|
||||||
@@ -295,21 +297,6 @@ def load_model(
|
|||||||
rms_norm=cfg.flash_attn_rms_norm,
|
rms_norm=cfg.flash_attn_rms_norm,
|
||||||
use_shifted_sparse_attn=True,
|
use_shifted_sparse_attn=True,
|
||||||
)
|
)
|
||||||
elif cfg.device not in ["mps", "cpu"] and not inference:
|
|
||||||
if cfg.sample_packing:
|
|
||||||
LOG.info("patching with flash attention for sample packing")
|
|
||||||
replace_llama_attn_with_flash_attn(
|
|
||||||
packed=True,
|
|
||||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
|
||||||
rms_norm=cfg.flash_attn_rms_norm,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if cfg.flash_attn_cross_entropy:
|
|
||||||
llama_patch_cross_entropy()
|
|
||||||
|
|
||||||
if cfg.flash_attn_rms_norm:
|
|
||||||
llama_patch_rms_norm()
|
|
||||||
|
|
||||||
elif cfg.xformers_attention:
|
elif cfg.xformers_attention:
|
||||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||||
hijack_llama_attention,
|
hijack_llama_attention,
|
||||||
|
|||||||
@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
|
|||||||
# Check that the output matches the expected output
|
# Check that the output matches the expected output
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
def test_output_multipack(self):
|
||||||
|
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
|
||||||
|
dtype = torch.float32
|
||||||
|
expected_output = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Check that the output matches the expected output
|
||||||
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user