Flex Attention + Packing with BlockMask support (#2363)
This commit is contained in:
@@ -67,9 +67,21 @@ def require_torch_2_5_1(test_case):
|
||||
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_6_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.6.0
|
||||
"""
|
||||
|
||||
def is_min_2_6_0():
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.6.0")
|
||||
|
||||
return unittest.skipUnless(is_min_2_6_0(), "test requires torch>=2.6.0")(test_case)
|
||||
|
||||
|
||||
def require_torch_lt_2_6_0(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.5.1
|
||||
Decorator marking a test that requires torch < 2.6.0
|
||||
"""
|
||||
|
||||
def is_max_2_6_0():
|
||||
|
||||
Reference in New Issue
Block a user