Fix flash-attn, xformers, and packing; add ChatML support

This commit is contained in:
Wing Lian
2023-08-04 10:09:16 -04:00
parent 0b01da0713
commit f93f0017cd
7 changed files with 56 additions and 13 deletions

View File

@@ -20,8 +20,8 @@ class TestExpandMask(unittest.TestCase):
[
[
[
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
]
@@ -29,14 +29,14 @@ class TestExpandMask(unittest.TestCase):
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
]
)
print(repr(_expand_mask(mask, dtype)))
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))