* support for true batches with multipack * patch the map dataset fetcher to handle batches with packed indexes * patch 4d mask creation for sdp attention * better handling for BetterTransformer * patch general case for 4d mask * setup forward patch. WIP * fix patch file * support for multipack w/o flash attention for llama * cleanup * add warning about bf16 vs fp16 for multipack with sdpa * bugfixes * add 4d multipack tests, refactor patches * update tests and add warnings * fix e2e file check * skip sdpa test if not at least torch 2.1.1, update docs
114 lines
3.6 KiB
Python
114 lines
3.6 KiB
Python
"""
|
|
Unit tests for the monkeypatch utils
|
|
"""
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from axolotl.monkeypatch.utils import (
|
|
get_cu_seqlens,
|
|
get_cu_seqlens_from_pos_ids,
|
|
get_max_seqlen_in_batch,
|
|
get_unpad_data,
|
|
)
|
|
|
|
|
|
class TestMonkeyPatchUtils(unittest.TestCase):
|
|
"""
|
|
Unit test class for monkeypatch utils
|
|
"""
|
|
|
|
def test_get_cu_seqlens_1d(self):
|
|
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
|
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
|
|
self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
|
|
|
|
def test_get_cu_seqlens_from_pos_ids_1d(self):
|
|
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
|
|
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
|
|
self.assertTrue(
|
|
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
|
)
|
|
|
|
def test_get_cu_seqlens_from_pos_ids_2d(self):
|
|
position_ids = torch.tensor(
|
|
[
|
|
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
|
|
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
|
|
]
|
|
)
|
|
target_res = torch.tensor(
|
|
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
|
|
)
|
|
self.assertTrue(
|
|
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
|
)
|
|
|
|
def test_get_max_seqlen_in_batch(self):
|
|
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
|
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
|
self.assertTrue(torch.allclose(get_max_seqlen_in_batch(attn_mask), target_res))
|
|
|
|
def test_get_unpad_data(self):
|
|
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
|
target_indices = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13])
|
|
target_cu_seqlen = torch.tensor([0, 4, 7, 12, 14], dtype=torch.int32)
|
|
target_max_seqlen_in_batch = 5
|
|
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
|
self.assertTrue(torch.allclose(target_indices, indices))
|
|
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
|
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
|
|
|
attn_mask = torch.tensor(
|
|
[
|
|
[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0],
|
|
[1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5],
|
|
]
|
|
)
|
|
target_indices = torch.tensor(
|
|
[
|
|
0,
|
|
1,
|
|
2,
|
|
3,
|
|
4,
|
|
5,
|
|
6,
|
|
7,
|
|
8,
|
|
9,
|
|
10,
|
|
11,
|
|
12,
|
|
13,
|
|
16,
|
|
17,
|
|
18,
|
|
19,
|
|
20,
|
|
21,
|
|
22,
|
|
23,
|
|
24,
|
|
25,
|
|
26,
|
|
27,
|
|
28,
|
|
29,
|
|
30,
|
|
31,
|
|
]
|
|
)
|
|
target_cu_seqlen = torch.tensor(
|
|
[0, 4, 7, 12, 14, 17, 22, 24, 27, 30], dtype=torch.int32
|
|
)
|
|
target_max_seqlen_in_batch = 5
|
|
indices, cu_seqlen, max_seqlen_in_batch = get_unpad_data(attn_mask)
|
|
self.assertTrue(torch.allclose(target_indices, indices))
|
|
self.assertTrue(torch.allclose(target_cu_seqlen, cu_seqlen))
|
|
self.assertEqual(target_max_seqlen_in_batch, max_seqlen_in_batch)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|