Files
axolotl/tests/e2e/utils.py
Wing Lian 00568c1539 support for true batches with multipack (#1230)
* support for true batches with multipack

* patch the map dataset fetcher to handle batches with packed indexes

* patch 4d mask creation for sdp attention

* better handling for BetterTransformer

* patch general case for 4d mask

* set up forward patch (WIP)

* fix patch file

* support for multipack without flash attention for llama

* cleanup

* add warning about bf16 vs fp16 for multipack with sdpa

* bugfixes

* add 4d multipack tests, refactor patches

* update tests and add warnings

* fix e2e file check

* skip sdpa test if not at least torch 2.1.1, update docs
2024-02-01 10:18:42 -05:00

48 lines
1.1 KiB
Python

"""
helper utils for tests
"""
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import PackageNotFoundError, version
from pathlib import Path


def with_temp_dir(test_func):
    """
    Decorator that runs ``test_func`` with a fresh temporary directory.

    The directory path is passed to the wrapped function as the ``temp_dir``
    keyword argument. The directory is removed recursively once the function
    finishes, whether it returned normally or raised.

    :param test_func: callable accepting a ``temp_dir`` keyword argument
    :return: the wrapped callable, which returns whatever ``test_func`` returns
    """

    @wraps(test_func)
    def wrapper(*args, **kwargs):
        # Create a temporary directory
        temp_dir = tempfile.mkdtemp()
        try:
            # Pass the temporary directory to the test function and propagate
            # its result rather than silently replacing it with None
            return test_func(*args, temp_dir=temp_dir, **kwargs)
        finally:
            # Clean up the directory after the test
            shutil.rmtree(temp_dir)

    return wrapper


def most_recent_subdir(path):
    """
    Return the direct subdirectory of ``path`` with the greatest
    ``os.path.getctime`` value, or ``None`` if ``path`` has no subdirectories.

    Note: ``getctime`` is the creation time on Windows but the last metadata
    change time on most Unix systems.
    """
    candidates = (child for child in Path(path).iterdir() if child.is_dir())
    return max(candidates, key=os.path.getctime, default=None)


def _release_tuple(ver):
    """
    Parse the numeric release segment of a version string into a tuple of ints.

    A local version label (anything after ``+``) is dropped. Parsing stops at
    the first component that has a non-numeric suffix, after keeping that
    component's leading digits. For example, ``"2.1.0+cu118"`` gives
    ``(2, 1, 0)``, ``"2.2.0a0+git1234"`` gives ``(2, 2, 0)``, and ``"10.0"``
    gives ``(10, 0)``. Pre-release tags are ignored, so ``"2.1.1rc1"`` parses
    as ``(2, 1, 1)``.
    """
    release = []
    for part in ver.split("+", 1)[0].split("."):
        suffix = part.lstrip("0123456789")
        digits = part[: len(part) - len(suffix)]
        if not digits:
            break
        release.append(int(digits))
        if suffix:
            # a pre/post/dev tag ends the release segment
            break
    return tuple(release)


def require_torch_2_1_1(test_case):
    """
    Decorator marking a test that requires torch >= 2.1.1

    The test (function or class) is skipped when torch is not installed, or
    when its installed version is older than 2.1.1. Versions are compared
    numerically (so ``"10.0.0"`` correctly counts as newer than ``"2.1.1"``),
    not as strings.
    """

    def is_min_2_1_1():
        try:
            torch_version = version("torch")
        except PackageNotFoundError:
            # no torch installed -> skip instead of erroring at import time
            return False
        return _release_tuple(torch_version) >= (2, 1, 1)

    return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)