cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that these patches targeted. This makes the following patches dead code: - llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*) - llama_expand_mask.py (patched _expand_mask, never called) - Related utility functions in monkeypatch/utils.py Closes axolotl-ai-cloud/axolotl#3331
This commit is contained in:
@@ -128,11 +128,9 @@ quartodoc:
|
|||||||
- monkeypatch.mistral_attn_hijack_flash
|
- monkeypatch.mistral_attn_hijack_flash
|
||||||
- monkeypatch.multipack
|
- monkeypatch.multipack
|
||||||
- monkeypatch.relora
|
- monkeypatch.relora
|
||||||
- monkeypatch.llama_expand_mask
|
|
||||||
- monkeypatch.lora_kernels
|
- monkeypatch.lora_kernels
|
||||||
- monkeypatch.utils
|
- monkeypatch.utils
|
||||||
- monkeypatch.btlm_attn_hijack_flash
|
- monkeypatch.btlm_attn_hijack_flash
|
||||||
- monkeypatch.llama_patch_multipack
|
|
||||||
- monkeypatch.stablelm_attn_hijack_flash
|
- monkeypatch.stablelm_attn_hijack_flash
|
||||||
- monkeypatch.trainer_fsdp_optim
|
- monkeypatch.trainer_fsdp_optim
|
||||||
- monkeypatch.transformers_fa_utils
|
- monkeypatch.transformers_fa_utils
|
||||||
|
|||||||
@@ -571,15 +571,6 @@ class PatchManager:
|
|||||||
LOG.info("Patching with xformers attention...")
|
LOG.info("Patching with xformers attention...")
|
||||||
hijack_llama_attention()
|
hijack_llama_attention()
|
||||||
|
|
||||||
def _patch_llama_sample_packing(self):
|
|
||||||
"""Apply sample packing patches for LLaMA models."""
|
|
||||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
|
||||||
hijack_llama_prepare_4d_mask,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
|
|
||||||
hijack_llama_prepare_4d_mask()
|
|
||||||
|
|
||||||
def _patch_llama_derived_model(self):
|
def _patch_llama_derived_model(self):
|
||||||
"""Modify all llama derived models in one block."""
|
"""Modify all llama derived models in one block."""
|
||||||
if self.cfg.is_llama_derived_model and not (
|
if self.cfg.is_llama_derived_model and not (
|
||||||
@@ -591,8 +582,6 @@ class PatchManager:
|
|||||||
self._patch_llama_flash_attention()
|
self._patch_llama_flash_attention()
|
||||||
elif self.cfg.xformers_attention:
|
elif self.cfg.xformers_attention:
|
||||||
self._patch_llama_xformers_attention()
|
self._patch_llama_xformers_attention()
|
||||||
elif self.cfg.sample_packing:
|
|
||||||
self._patch_llama_sample_packing()
|
|
||||||
elif self.cfg.s2_attention:
|
elif self.cfg.s2_attention:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
"Shifted-sparse attention not currently implemented without flash attention."
|
"Shifted-sparse attention not currently implemented without flash attention."
|
||||||
|
|||||||
@@ -1,24 +0,0 @@
|
|||||||
"""
|
|
||||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
|
||||||
|
|
||||||
|
|
||||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
|
||||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
|
||||||
inverted_mask = 1.0 - masked_zero_one_mask
|
|
||||||
|
|
||||||
return inverted_mask.masked_fill(
|
|
||||||
inverted_mask.to(torch.bool), torch.finfo(dtype).min
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hijack_expand_mask():
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
transformers.models.llama.modeling_llama._expand_mask = _expand_mask
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
"""
|
|
||||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
|
||||||
"""
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import (
|
|
||||||
patched_prepare_4d_causal_attention_mask,
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def hijack_llama_prepare_4d_mask():
|
|
||||||
from transformers import modeling_attn_mask_utils
|
|
||||||
from transformers.models.llama import modeling_llama
|
|
||||||
|
|
||||||
modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
|
||||||
)
|
|
||||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
|
|
||||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
|
||||||
)
|
|
||||||
modeling_llama._prepare_4d_causal_attention_mask = (
|
|
||||||
patched_prepare_4d_causal_attention_mask
|
|
||||||
)
|
|
||||||
modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
|
|
||||||
patched_prepare_4d_causal_attention_mask
|
|
||||||
)
|
|
||||||
@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers.modeling_attn_mask_utils import (
|
|
||||||
_prepare_4d_causal_attention_mask,
|
|
||||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
|
||||||
)
|
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
@torch.jit.script
|
||||||
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
|
|||||||
setattr(parent, child_name, value)
|
setattr(parent, child_name, value)
|
||||||
|
|
||||||
|
|
||||||
def mask_2d_to_4d(
|
|
||||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
|
||||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
|
||||||
when they attend to each other within that sequence.
|
|
||||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
|
||||||
"""
|
|
||||||
bsz, src_len = mask.size()
|
|
||||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
|
||||||
|
|
||||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
|
||||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
|
||||||
|
|
||||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
|
||||||
binary_mask = torch.where(
|
|
||||||
mask != 0,
|
|
||||||
torch.tensor(1, device=mask.device).to(dtype),
|
|
||||||
torch.tensor(0, device=mask.device).to(dtype),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a block-diagonal mask.
|
|
||||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
|
||||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
|
||||||
|
|
||||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
|
||||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
|
||||||
mask.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
|
||||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
|
||||||
|
|
||||||
return masked_zero_one_mask
|
|
||||||
|
|
||||||
|
|
||||||
def patched_prepare_4d_causal_attention_mask(
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
):
|
|
||||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
|
||||||
return _prepare_4d_causal_attention_mask(
|
|
||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
attention_mask: Optional[torch.Tensor],
|
|
||||||
*args,
|
|
||||||
):
|
|
||||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
|
||||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
|
||||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
|
||||||
*args,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detab_code(code: str) -> Tuple[str, str]:
|
def detab_code(code: str) -> Tuple[str, str]:
|
||||||
try:
|
try:
|
||||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||||
|
|||||||
@@ -1,45 +0,0 @@
|
|||||||
"""
|
|
||||||
Unit tests for the monkey patch for expand mask to handle packed sequences
|
|
||||||
"""
|
|
||||||
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
|
||||||
|
|
||||||
|
|
||||||
class TestExpandMask(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test class for attention mask expansion for packed sequences
|
|
||||||
"""
|
|
||||||
|
|
||||||
def test_output(self):
|
|
||||||
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
|
|
||||||
dtype = torch.float32
|
|
||||||
expected_output = torch.tensor(
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
|
||||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
|
||||||
]
|
|
||||||
],
|
|
||||||
[
|
|
||||||
[
|
|
||||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
|
||||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
|
||||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
|
||||||
]
|
|
||||||
],
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Check that the output matches the expected output
|
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
Reference in New Issue
Block a user