support for true batches with multipack (#1230)

* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
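
The 4d mask items above concern keeping packed samples from attending to each other when flash attention is not used. The following is a minimal pure-Python sketch of that idea, not the code from this commit: `packed_causal_mask`, `batch_masks`, and the convention that per-token sample ids are 1-based with 0 marking padding are all assumptions made for illustration.

```python
def packed_causal_mask(seq_ids):
    """Build an n x n boolean mask for one packed row (hypothetical helper).

    seq_ids[i] is the 1-based id of the sample token i belongs to; 0 marks
    padding. Entry [i][j] is True when token i may attend to token j.
    """
    n = len(seq_ids)
    return [
        [
            seq_ids[i] != 0                 # padding tokens attend to nothing
            and seq_ids[i] == seq_ids[j]    # only within the same packed sample
            and j <= i                      # causal: no attending to later tokens
            for j in range(n)
        ]
        for i in range(n)
    ]


def batch_masks(batch_seq_ids):
    """Nest per-row masks as [batch][1][n][n]; the size-1 axis stands in for heads."""
    return [[packed_causal_mask(row)] for row in batch_seq_ids]
```

For a row packed as `[1, 1, 2, 2, 0]`, the third token (first token of sample 2) can attend only to itself, which is the block-diagonal structure a plain causal mask would not give.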
Wing Lian
2024-02-01 10:18:42 -05:00
committed by GitHub
parent c67fb71583
commit 00568c1539
24 changed files with 573 additions and 246 deletions


@@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se
 from axolotl.utils.data import encode_packed_pretraining
-class TestPacking(unittest.TestCase):
+class TestPretrainingPacking(unittest.TestCase):
     """
     Test class for packing streaming dataset sequences
     """