support for true batches with multipack (#1230)
* support for true batches with multipack
* patch the map dataset fetcher to handle batches with packed indexes
* patch 4d mask creation for sdp attention
* better handling for BetterTransformer
* patch general case for 4d mask
* setup forward patch. WIP
* fix patch file
* support for multipack w/o flash attention for llama
* cleanup
* add warning about bf16 vs fp16 for multipack with sdpa
* bugfixes
* add 4d multipack tests, refactor patches
* update tests and add warnings
* fix e2e file check
* skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
@@ -30,6 +30,20 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_cu_seqlens_from_pos_ids_2d(self):
    """Check ``get_cu_seqlens_from_pos_ids`` on a two-row batch of position ids.

    Each input row holds several runs of position ids that restart at 0.
    The expected tensor lists, per row, the offsets at which those runs
    begin, followed by the row length (16). The second row has one fewer
    boundary, so its final value is repeated to keep both rows the same
    width.
    """
    packed_position_ids = torch.tensor(
        [
            [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
            [0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
        ]
    )
    expected_cu_seqlens = torch.tensor(
        [[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
    )

    # Only the first element of the helper's return value is compared here.
    actual_cu_seqlens = get_cu_seqlens_from_pos_ids(packed_position_ids)[0]

    self.assertTrue(torch.allclose(actual_cu_seqlens, expected_cu_seqlens))
|
||||
|
||||
def test_get_max_seqlen_in_batch(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||
|
||||
Reference in New Issue
Block a user