cleanup: remove dead SDPA patches (#3488) [skip ci]
Transformers 5.x routes attention through sdpa_attention.py and no longer calls the _prepare_4d_causal_attention_mask* or _expand_mask functions that these patches targeted. This makes the following patches dead code: - llama_patch_multipack.py (patched _prepare_4d_causal_attention_mask*) - llama_expand_mask.py (patched _expand_mask, never called) - Related utility functions in monkeypatch/utils.py Closes axolotl-ai-cloud/axolotl#3331
This commit is contained in:
@@ -1,45 +0,0 @@
|
||||
"""
|
||||
Unit tests for the monkey patch for expand mask to handle packed sequences
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
||||
|
||||
|
||||
class TestExpandMask(unittest.TestCase):
|
||||
"""
|
||||
Test class for attention mask expansion for packed sequences
|
||||
"""
|
||||
|
||||
def test_output(self):
|
||||
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
|
||||
dtype = torch.float32
|
||||
expected_output = torch.tensor(
|
||||
[
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
|
||||
]
|
||||
],
|
||||
[
|
||||
[
|
||||
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||
]
|
||||
],
|
||||
]
|
||||
)
|
||||
# Check that the output matches the expected output
|
||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user