support for true batches with multipack (#1230)

* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* setup forward patch. WIP

* fix patch file

* support for multipack w/o flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
This commit is contained in:
Wing Lian
2024-02-01 10:18:42 -05:00
committed by GitHub
parent c67fb71583
commit 00568c1539
24 changed files with 573 additions and 246 deletions

View File

@@ -4,7 +4,9 @@ helper utils for tests
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
@@ -31,3 +33,15 @@ def most_recent_subdir(path):
subdir = max(subdirectories, key=os.path.getctime)
return subdir
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)