From 1a538be9c26581f1de22672b41fca3ecdb7dc6a4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 26 Jan 2024 00:41:24 -0500 Subject: [PATCH] add a prelim test for expading the 4d mask --- tests/test_expand_mask.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_expand_mask.py b/tests/test_expand_mask.py index 01241c295..24c743ece 100644 --- a/tests/test_expand_mask.py +++ b/tests/test_expand_mask.py @@ -39,6 +39,32 @@ class TestExpandMask(unittest.TestCase): # Check that the output matches the expected output self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) + def test_output_multipack(self): + mask = torch.tensor([[1, 1, 1, 0], [2, 2, 3, 3]]) + dtype = torch.float32 + expected_output = torch.tensor( + [ + [ + [ + [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38], + [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38], + ] + ], + [ + [ + [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38], + [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38], + [-3.4028e38, -3.4028e38, 0.0000e00, -3.4028e38], + [-3.4028e38, -3.4028e38, 0.0000e00, 0.0000e00], + ] + ], + ] + ) + # Check that the output matches the expected output + self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output)) + if __name__ == "__main__": unittest.main()