From 7ddfb2d8a0531122649a30a1b1d8bb40cd0e7191 Mon Sep 17 00:00:00 2001
From: Avaya Aggarwal <119044997+OnePunchMonk@users.noreply.github.com>
Date: Fri, 20 Mar 2026 15:40:41 +0530
Subject: [PATCH] cleanup: remove dead SDPA patches (#3488) [skip ci]

Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331
---
 _quarto.yml                                   |  2 -
 src/axolotl/loaders/patch_manager.py          | 11 ----
 src/axolotl/monkeypatch/llama_expand_mask.py  | 24 -------
 .../monkeypatch/llama_patch_multipack.py      | 26 --------
 src/axolotl/monkeypatch/utils.py              | 66 +------------------
 tests/test_expand_mask.py                     | 45 -------------
 6 files changed, 1 insertion(+), 173 deletions(-)
 delete mode 100644 src/axolotl/monkeypatch/llama_expand_mask.py
 delete mode 100644 src/axolotl/monkeypatch/llama_patch_multipack.py
 delete mode 100644 tests/test_expand_mask.py

diff --git a/_quarto.yml b/_quarto.yml
index 5e1169102..404125d1c 100644
--- a/_quarto.yml
+++ b/_quarto.yml
@@ -128,11 +128,9 @@ quartodoc:
         - monkeypatch.mistral_attn_hijack_flash
         - monkeypatch.multipack
         - monkeypatch.relora
-        - monkeypatch.llama_expand_mask
         - monkeypatch.lora_kernels
         - monkeypatch.utils
         - monkeypatch.btlm_attn_hijack_flash
-        - monkeypatch.llama_patch_multipack
         - monkeypatch.stablelm_attn_hijack_flash
         - monkeypatch.trainer_fsdp_optim
         - monkeypatch.transformers_fa_utils
diff --git a/src/axolotl/loaders/patch_manager.py b/src/axolotl/loaders/patch_manager.py
index bddd388e4..38cc198d3 100644
--- a/src/axolotl/loaders/patch_manager.py
+++ b/src/axolotl/loaders/patch_manager.py
@@ -571,15 +571,6 @@ class PatchManager:
         LOG.info("Patching with xformers attention...")
         hijack_llama_attention()
 
-    def _patch_llama_sample_packing(self):
-        """Apply sample packing patches for LLaMA models."""
-        from axolotl.monkeypatch.llama_patch_multipack import (
-            hijack_llama_prepare_4d_mask,
-        )
-
-        LOG.info("Patching llama _prepare_4d_causal_attention_mask*...")
-        hijack_llama_prepare_4d_mask()
-
     def _patch_llama_derived_model(self):
         """Modify all llama derived models in one block."""
         if self.cfg.is_llama_derived_model and not (
@@ -591,8 +582,6 @@
                 self._patch_llama_flash_attention()
             elif self.cfg.xformers_attention:
                 self._patch_llama_xformers_attention()
-            elif self.cfg.sample_packing:
-                self._patch_llama_sample_packing()
             elif self.cfg.s2_attention:
                 raise NotImplementedError(
                     "Shifted-sparse attention not currently implemented without flash attention."
diff --git a/src/axolotl/monkeypatch/llama_expand_mask.py b/src/axolotl/monkeypatch/llama_expand_mask.py
deleted file mode 100644
index 5cfb7818e..000000000
--- a/src/axolotl/monkeypatch/llama_expand_mask.py
+++ /dev/null
@@ -1,24 +0,0 @@
-"""
-expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
-"""
-
-from typing import Optional
-
-import torch
-
-from axolotl.monkeypatch.utils import mask_2d_to_4d
-
-
-def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
-    masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
-    inverted_mask = 1.0 - masked_zero_one_mask
-
-    return inverted_mask.masked_fill(
-        inverted_mask.to(torch.bool), torch.finfo(dtype).min
-    )
-
-
-def hijack_expand_mask():
-    import transformers
-
-    transformers.models.llama.modeling_llama._expand_mask = _expand_mask
diff --git a/src/axolotl/monkeypatch/llama_patch_multipack.py b/src/axolotl/monkeypatch/llama_patch_multipack.py
deleted file mode 100644
index 8d234881f..000000000
--- a/src/axolotl/monkeypatch/llama_patch_multipack.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""
-Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
-"""
-
-from axolotl.monkeypatch.utils import (
-    patched_prepare_4d_causal_attention_mask,
-    patched_prepare_4d_causal_attention_mask_for_sdpa,
-)
-
-
-def hijack_llama_prepare_4d_mask():
-    from transformers import modeling_attn_mask_utils
-    from transformers.models.llama import modeling_llama
-
-    modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = (
-        patched_prepare_4d_causal_attention_mask_for_sdpa
-    )
-    modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = (
-        patched_prepare_4d_causal_attention_mask_for_sdpa
-    )
-    modeling_llama._prepare_4d_causal_attention_mask = (
-        patched_prepare_4d_causal_attention_mask
-    )
-    modeling_attn_mask_utils._prepare_4d_causal_attention_mask = (
-        patched_prepare_4d_causal_attention_mask
-    )
diff --git a/src/axolotl/monkeypatch/utils.py b/src/axolotl/monkeypatch/utils.py
index 4c6a4de11..3ec242ef0 100644
--- a/src/axolotl/monkeypatch/utils.py
+++ b/src/axolotl/monkeypatch/utils.py
@@ -3,15 +3,10 @@ Shared utils for the monkeypatches
 """
 
 import re
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_causal_attention_mask,
-    _prepare_4d_causal_attention_mask_for_sdpa,
-)
-from transformers.utils import is_torch_bf16_gpu_available
 
 
 @torch.jit.script
@@ -170,65 +165,6 @@ def set_module_name(model, name, value):
     setattr(parent, child_name, value)
 
 
-def mask_2d_to_4d(
-    mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
-):
-    """
-    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
-    This expansion handles packed sequences so that sequences share the same attention mask integer value
-    when they attend to each other within that sequence.
-    This expansion transforms the mask to lower triangular form to prevent future peeking.
-    """
-    bsz, src_len = mask.size()
-    tgt_len = tgt_len if tgt_len is not None else src_len
-
-    mask = mask.unsqueeze(1).unsqueeze(2)
-    mask = mask.expand(bsz, 1, tgt_len, src_len)
-
-    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
-    binary_mask = torch.where(
-        mask != 0,
-        torch.tensor(1, device=mask.device).to(dtype),
-        torch.tensor(0, device=mask.device).to(dtype),
-    )
-
-    # Create a block-diagonal mask.
-    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
-    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
-
-    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
-    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
-        mask.device
-    )
-
-    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
-    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
-
-    return masked_zero_one_mask
-
-
-def patched_prepare_4d_causal_attention_mask(
-    attention_mask: Optional[torch.Tensor],
-    *args,
-):
-    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
-    return _prepare_4d_causal_attention_mask(
-        mask_2d_to_4d(attention_mask, dtype=dtype),
-        *args,
-    )
-
-
-def patched_prepare_4d_causal_attention_mask_for_sdpa(
-    attention_mask: Optional[torch.Tensor],
-    *args,
-):
-    dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
-    return _prepare_4d_causal_attention_mask_for_sdpa(
-        mask_2d_to_4d(attention_mask, dtype=dtype),
-        *args,
-    )
-
-
 def detab_code(code: str) -> Tuple[str, str]:
     try:
         spaces = re.match(r"([\s\t]{1,})", code).group(0)
diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py
deleted file mode 100644
index 1c69ca234..000000000
--- a/tests/test_expand_mask.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""
-Unit tests for the monkey patch for expand mask to handle packed sequences
-"""
-
-import unittest
-
-import torch
-
-from axolotl.monkeypatch.llama_expand_mask import _expand_mask
-
-
-class TestExpandMask(unittest.TestCase):
-    """
-    Test class for attention mask expansion for packed sequences
-    """
-
-    def test_output(self):
-        mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
-        dtype = torch.float32
-        expected_output = torch.tensor(
-            [
-                [
-                    [
-                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
-                        [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
-                        [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
-                        [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
-                    ]
-                ],
-                [
-                    [
-                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
-                        [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
-                        [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
-                        [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
-                    ]
-                ],
-            ]
-        )
-        # Check that the output matches the expected output
-        self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
-
-
-if __name__ == "__main__":
-    unittest.main()
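
A quick sanity check a reviewer could run (a sketch, not part of the patch; the
module path is taken from the removed hijack functions above, and the check
assumes a Transformers 5.x install): if the hooks these patches used to
overwrite are no longer defined on the llama modeling module, the removed code
paths were indeed dead.

    # Sketch: confirm the symbols that hijack_llama_prepare_4d_mask() and
    # hijack_expand_mask() overwrote are no longer present on modeling_llama.
    from transformers.models.llama import modeling_llama

    for name in (
        "_prepare_4d_causal_attention_mask",
        "_prepare_4d_causal_attention_mask_for_sdpa",
        "_expand_mask",
    ):
        status = "present" if hasattr(modeling_llama, name) else "absent"
        print(f"{name}: {status}")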