disable ray tests for latest torch release (#2328)

* disable ray tests for latest torch release

* move decorator from class to method
This commit is contained in:
Wing Lian
2025-02-12 18:29:02 -05:00
committed by GitHub
parent e37a4a536a
commit 30046315d9
2 changed files with 15 additions and 1 deletion

View File

@@ -9,7 +9,7 @@ from pathlib import Path
import pytest
import yaml
from accelerate.test_utils import execute_subprocess_async
from e2e.utils import check_tensorboard
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
from axolotl.utils.dict import DictDefault
@@ -24,6 +24,7 @@ class TestMultiGPURay:
Test cases for AnyScale Ray post training
"""
@require_torch_lt_2_6_0
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -80,6 +81,7 @@ class TestMultiGPURay:
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
)
@require_torch_lt_2_6_0
@pytest.mark.parametrize(
"gradient_accumulation_steps",
[1, 2],

View File

@@ -66,6 +66,18 @@ def require_torch_2_5_1(test_case):
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
def require_torch_lt_2_6_0(test_case):
    """
    Decorator marking a test that requires torch < 2.6.0.

    The wrapped test is skipped (via unittest.skipUnless) when the installed
    torch version is 2.6.0 or newer.
    """

    def is_lt_2_6_0():
        # Parse the installed torch version and compare against the 2.6.0 cutoff.
        torch_version = version.parse(torch.__version__)
        return torch_version < version.parse("2.6.0")

    return unittest.skipUnless(is_lt_2_6_0(), "test requires torch<2.6.0")(test_case)
def is_hopper():
    """Return True when the current CUDA device reports compute capability 9.0 (Hopper)."""
    return torch.cuda.get_device_capability() == (9, 0)