---
title: Multipack (Sample Packing)
description: Multipack is a technique to pack multiple sequences into a single batch to increase training throughput.
---

## Visualization of Multipack with Flash Attention

Flash Attention's variable-length interface takes cumulative sequence lengths
(`cu_seqlens`) rather than an attention mask, so we do not need to construct a
4D attention mask. We only need to concatenate the sequences into a single
batch and tell Flash Attention where each new sequence begins.

The diagrams below assume a 4k context and a batch size of 4.
Each character represents 256 tokens, and `X` represents padding.

```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
 [ B B B B B B ]
 [ C C C C C C C ]
 [ D D D D ]]

[[ E E E E E E E E ]
 [ F F F F ]
 [ G G G ]
 [ H H H H ]]

[[ I I I ]
 [ J J J ]
 [ K K K K K ]
 [ L L L ]]
```

After padding to the longest input in each step:

```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
 [ B B B B B B X X X X X ]
 [ C C C C C C C X X X X ]
 [ D D D D X X X X X X X ]]

[[ E E E E E E E E ]
 [ F F F F X X X X ]
 [ G G G X X X X X ]
 [ H H H H X X X X ]]

[[ I I I X X ]
 [ J J J X X ]
 [ K K K K K ]
 [ L L L X X ]]
```

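To make the cost of padding concrete, here is a minimal sketch that counts how many slots hold real tokens when each step is padded to its longest input. The sequence lengths are read off the unpadded diagram (one unit per character, i.e. 256 tokens); the names `steps` and `padded_slots` are illustrative and not part of any library.

```python
# Sequence lengths from the unpadded diagram, in 256-token blocks.
steps = [
    [11, 6, 7, 4],  # A, B, C, D
    [8, 4, 3, 4],   # E, F, G, H
    [3, 3, 5, 3],   # I, J, K, L
]


def padded_slots(step):
    """Slots consumed when every sequence is padded to the longest in the step."""
    return max(step) * len(step)


real = sum(sum(step) for step in steps)
padded = sum(padded_slots(step) for step in steps)
print(f"real blocks: {real}, padded slots: {padded}, utilization: {real / padded:.1%}")
```

In this example, 61 of the 96 padded slots carry real tokens; the rest is padding that still costs compute.
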
With packing (note that the effective number of tokens per step is the same, but the true batch size is 1):

```
   0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
   B C C C C C C C D D D D E E E E
   E E E E F F F F G G G H H H H I
   I I J J J K K K K K L L L X X X ]]
```

cu_seqlens, in units of 256-token blocks (the final entry closes the trailing padding):
[[ 0, 11, 17, 24, 28, 36, 40, 43, 47, 50, 53, 58, 61, 64 ]]

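The cumulative boundaries can be derived from the sequence lengths with a running sum. The sketch below is illustrative rather than axolotl's actual implementation: it uses the lengths from the unpadded diagram (in 256-token blocks), and the helper name `cu_seqlens_for` and the choice to add a final entry for trailing padding are assumptions made for this example.

```python
from itertools import accumulate

# 4 rows of 16 characters in the packed diagram = 64 blocks of 256 tokens.
PACKED_BLOCKS = 4 * 16

# Lengths of sequences A..L, read off the unpadded diagram.
lengths = [11, 6, 7, 4, 8, 4, 3, 4, 3, 3, 5, 3]


def cu_seqlens_for(lengths, total):
    """Return cumulative sequence boundaries, starting at 0.

    If the sequences do not fill the packed row, a final entry equal to
    `total` is appended so the trailing padding forms its own segment.
    """
    bounds = [0, *accumulate(lengths)]
    if bounds[-1] > total:
        raise ValueError("sequences do not fit in the packed row")
    if bounds[-1] < total:
        bounds.append(total)
    return bounds


cu_seqlens = cu_seqlens_for(lengths, PACKED_BLOCKS)
print(cu_seqlens)
```

Consecutive pairs in the result give the start and end of each sequence, which is all a variable-length attention kernel needs to keep sequences from attending to one another.
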
## Multipack without Flash Attention

Multipack can still be used without Flash Attention, but with lower packing
efficiency: without Flash Attention, context length limits prevent us from
joining multiple batches into a single batch. We can use either PyTorch's Scaled
Dot Product Attention implementation or the native PyTorch attention implementation,
along with [4D attention masks](https://github.com/huggingface/transformers/pull/27539),
to pack sequences together and prevent attention across sequence boundaries.

<img src="./images/4d-mask.png" alt="axolotl" width="800">

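As a rough illustration of what such a mask encodes, the sketch below builds a block-diagonal causal mask for a packed row in plain Python. It is not axolotl's or Transformers' code: the function name `packed_causal_mask` is hypothetical, and a real implementation would build a tensor with additional batch and head dimensions (hence "4D").

```python
def packed_causal_mask(lengths):
    """Return allowed[i][j], True when query position i may attend to key j.

    A position may attend only to earlier-or-equal positions that belong to
    the same packed sequence, which yields a block-diagonal causal mask.
    """
    # Label each position with the index of the sequence it belongs to.
    seq_ids = [idx for idx, n in enumerate(lengths) for _ in range(n)]
    total = len(seq_ids)
    return [
        [seq_ids[i] == seq_ids[j] and j <= i for j in range(total)]
        for i in range(total)
    ]


# Two packed sequences of lengths 3 and 2.
mask = packed_causal_mask([3, 2])
for row in mask:
    print("".join("1" if ok else "." for ok in row))
# 1....
# 11...
# 111..
# ...1.
# ...11
```

Each block on the diagonal is an ordinary causal mask for one sequence, and the empty off-diagonal blocks are what stop tokens from one sequence attending to another.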