removing 2.3.1 (#2294)
This commit is contained in:
@@ -12,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -23,7 +23,6 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_3_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -42,18 +42,6 @@ def most_recent_subdir(path):
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_3_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.3.1
|
||||
"""
|
||||
|
||||
def is_min_2_3_1():
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.3.1")
|
||||
|
||||
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_4_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.5.1
|
||||
|
||||
Reference in New Issue
Block a user