support for true batches with multipack (#1230)
* support for true batches with multipack * patch the map dataset fetcher to handle batches with packed indexes * patch 4d mask creation for sdp attention * better handling for BetterTransformer * patch general case for 4d mask * setup forward patch. WIP * fix patch file * support for multipack w/o flash attention for llama * cleanup * add warning about bf16 vs fp16 for multipack with sdpa * bugfixes * add 4d multipack tests, refactor patches * update tests and add warnings * fix e2e file check * skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
@@ -11,7 +11,7 @@ from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Se
|
||||
from axolotl.utils.data import encode_packed_pretraining
|
||||
|
||||
|
||||
class TestPacking(unittest.TestCase):
|
||||
class TestPretrainingPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user