* adding pre-commit auto-update GH action and bumping plugin versions * running updated pre-commit plugins * sorry to revert, but pylint complained * Update .pre-commit-config.yaml Co-authored-by: Wing Lian <wing.lian@gmail.com> --------- Co-authored-by: Dan Saunders <dan@axolotl.ai> Co-authored-by: Wing Lian <wing.lian@gmail.com>
46 lines
1.4 KiB
Python
46 lines
1.4 KiB
Python
"""
|
|
Unit tests for the monkey patch for expand mask to handle packed sequences
|
|
"""
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
|
|
|
|
|
|
class TestExpandMask(unittest.TestCase):
    """
    Verifies the expanded attention mask produced for packed sequences
    """

    def test_output(self):
        # Two rows of per-position ids; equal non-zero ids are expected to be
        # mutually visible (presumably one packed sequence each -- confirm
        # against _expand_mask). The trailing 0 is expected to be fully masked.
        packed_mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
        target_dtype = torch.float32

        # Additive-mask sentinels used by the expected output below.
        keep = 0.0000e00
        drop = -3.4028e38

        first_sample = [
            [keep, drop, drop, drop],
            [keep, keep, drop, drop],
            [keep, keep, keep, drop],
            [drop, drop, drop, keep],
        ]
        second_sample = [
            [keep, drop, drop, drop],
            [drop, keep, drop, drop],
            [drop, keep, keep, drop],
            [drop, drop, drop, drop],
        ]
        # Wrap each 4x4 matrix in a singleton dimension -> shape (2, 1, 4, 4).
        expected_output = torch.tensor([[first_sample], [second_sample]])

        # Check that the output matches the expected output
        actual_output = _expand_mask(packed_mask, target_dtype)
        matches = torch.allclose(actual_output, expected_output)
        self.assertTrue(matches)
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the test cases when this file is executed directly as a script.
    unittest.main()
|