diff --git a/tests/e2e/multigpu/test_ray.py b/tests/e2e/multigpu/test_ray.py
index d7e4ddfcf..72ec69aa8 100644
--- a/tests/e2e/multigpu/test_ray.py
+++ b/tests/e2e/multigpu/test_ray.py
@@ -9,7 +9,7 @@ from pathlib import Path
 import pytest
 import yaml
 from accelerate.test_utils import execute_subprocess_async
-from e2e.utils import check_tensorboard
+from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
 
 from axolotl.utils.dict import DictDefault
 
@@ -24,6 +24,7 @@ class TestMultiGPURay:
     Test cases for AnyScale Ray post training
     """
 
+    @require_torch_lt_2_6_0
     def test_lora_ddp(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
@@ -80,6 +81,7 @@ class TestMultiGPURay:
             temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )
 
+    @require_torch_lt_2_6_0
     @pytest.mark.parametrize(
         "gradient_accumulation_steps",
         [1, 2],
diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py
index 2baead7d2..a9f7fb28d 100644
--- a/tests/e2e/utils.py
+++ b/tests/e2e/utils.py
@@ -66,6 +66,18 @@ def require_torch_2_5_1(test_case):
     return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
 
 
+def require_torch_lt_2_6_0(test_case):
+    """
+    Decorator marking a test that requires torch<2.6.0
+    """
+
+    def is_max_2_6_0():
+        torch_version = version.parse(torch.__version__)
+        return torch_version < version.parse("2.6.0")
+
+    return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)
+
+
 def is_hopper():
     compute_capability = torch.cuda.get_device_capability()
     return compute_capability == (9, 0)