Multipack simplify for Mixtral (#1142)
This commit is contained in:
@@ -5,7 +5,12 @@ import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.utils import (
|
||||
get_cu_seqlens,
|
||||
get_cu_seqlens_from_pos_ids,
|
||||
get_max_seqlen_in_batch,
|
||||
get_unpad_data,
|
||||
)
|
||||
|
||||
|
||||
class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
@@ -25,6 +30,70 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_max_seqlen_in_batch(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||
self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
|
||||
|
||||
def test_get_unpad_data(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
|
||||
target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
|
||||
target_max_seqlen_in_batch = 5
|
||||
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
||||
self.assertTrue(torch.allclose(target_indices, indices))
|
||||
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
||||
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
||||
|
||||
attn_mask = torch.tensor(
|
||||
[
|
||||
[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
|
||||
[1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
|
||||
]
|
||||
)
|
||||
target_indices = torch.tensor(
|
||||
[
|
||||
0,
|
||||
1,
|
||||
2,
|
||||
3,
|
||||
4,
|
||||
5,
|
||||
6,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
10,
|
||||
11,
|
||||
12,
|
||||
13,
|
||||
16,
|
||||
17,
|
||||
18,
|
||||
19,
|
||||
20,
|
||||
21,
|
||||
22,
|
||||
23,
|
||||
24,
|
||||
25,
|
||||
26,
|
||||
27,
|
||||
28,
|
||||
29,
|
||||
30,
|
||||
31,
|
||||
]
|
||||
)
|
||||
target_cu_seqlen = torch.tensor(
|
||||
[0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
|
||||
)
|
||||
target_max_seqlen_in_batch = 5
|
||||
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
||||
self.assertTrue(torch.allclose(target_indices, indices))
|
||||
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
||||
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user