calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier

This commit is contained in:
Wing Lian
2023-08-10 17:16:01 -04:00
parent 57d9bf711c
commit a07f432d9c
4 changed files with 17 additions and 6 deletions

View File

@@ -36,7 +36,6 @@ class TestExpandMask(unittest.TestCase):
],
]
)
print(repr(_expand_mask(mask, dtype)))
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))