cleanup: remove dead SDPA patches (#3488) [skip ci]

Transformers 5.x routes attention through sdpa_attention.py and no longer
calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that
these patches targeted. This makes the following patches dead code:

- llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*)
- llama_expand_mask.py (patched _expand_mask, never called)
- Related utility functions in monkeypatch/utils.py

Closes axolotl-ai-cloud/axolotl#3331
This commit is contained in:
Avaya Aggarwal
2026-03-20 15:40:41 +05:30
committed by GitHub
parent c57acef2c7
commit 7ddfb2d8a0
6 changed files with 1 additions and 173 deletions

View File

@@ -1,45 +0,0 @@
"""
Unit tests for the monkey patch for expand mask to handle packed sequences
"""
import unittest
import torch
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
class TestExpandMask(unittest.TestCase):
"""
Test class for attention mask expansion for packed sequences
"""
def test_output(self):
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
dtype = torch.float32
expected_output = torch.tensor(
[
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
]
],
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
]
)
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
if __name__ == "__main__":
unittest.main()