disable ray tests for latest torch release (#2328)
* disable ray tests for latest torch release * move decorator from class to method
This commit is contained in:
@@ -9,7 +9,7 @@ from pathlib import Path
|
||||
import pytest
|
||||
import yaml
|
||||
from accelerate.test_utils import execute_subprocess_async
|
||||
from e2e.utils import check_tensorboard
|
||||
from e2e.utils import check_tensorboard, require_torch_lt_2_6_0
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
@@ -24,6 +24,7 @@ class TestMultiGPURay:
|
||||
Test cases for AnyScale Ray post training
|
||||
"""
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -80,6 +81,7 @@ class TestMultiGPURay:
|
||||
temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
|
||||
)
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
@pytest.mark.parametrize(
|
||||
"gradient_accumulation_steps",
|
||||
[1, 2],
|
||||
|
||||
@@ -66,6 +66,18 @@ def require_torch_2_5_1(test_case):
|
||||
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
|
||||
|
||||
|
||||
def require_torch_lt_2_6_0(test_case):
    """
    Decorator marking a test that requires torch < 2.6.0.

    The wrapped test is skipped (via ``unittest.skipUnless``) when the
    installed torch version is 2.6.0 or newer.
    """

    def is_lt_2_6_0():
        # Strict less-than: 2.6.0 itself must NOT satisfy the requirement.
        torch_version = version.parse(torch.__version__)
        return torch_version < version.parse("2.6.0")

    return unittest.skipUnless(is_lt_2_6_0(), "test requires torch<2.6.0")(test_case)
|
||||
|
||||
|
||||
def is_hopper():
    """Return True when the current CUDA device is Hopper (compute capability 9.0)."""
    # (9, 0) is the SM major/minor pair reported for Hopper-class GPUs.
    return torch.cuda.get_device_capability() == (9, 0)
|
||||
|
||||
Reference in New Issue
Block a user