Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
1a538be9c2 add a prelim test for expading the 4d mask 2024-01-26 00:41:24 -05:00

View File

@@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase):
# Check that the output matches the expected output # Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
def test_output_multipack(self):
mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]])
dtype = torch.float32
expected_output = torch.tensor(
[
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00],
]
],
]
)
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()