Compare commits
1 Commits
diffusion-
...
sdpa-multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1a538be9c2 |
@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
|
|||||||
# Check that the output matches the expected output
|
# Check that the output matches the expected output
|
||||||
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
def test_output_multipack(self):
|
||||||
|
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
|
||||||
|
dtype = torch.float32
|
||||||
|
expected_output = torch.tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
[
|
||||||
|
[
|
||||||
|
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
|
||||||
|
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
|
||||||
|
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
|
||||||
|
]
|
||||||
|
],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Check that the output matches the expected output
|
||||||
|
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user