GRPO (#2307)
This commit is contained in:
@@ -78,6 +78,24 @@ def require_torch_lt_2_6_0(test_case):
|
||||
return unittest.skipUnless(is_max_2_6_0(), "test requires torch<2.6.0")(test_case)
|
||||
|
||||
|
||||
def require_vllm(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a vllm to be installed
|
||||
"""
|
||||
|
||||
def is_vllm_installed():
|
||||
try:
|
||||
import vllm # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
return unittest.skipUnless(
|
||||
is_vllm_installed(), "test requires a vllm to be installed"
|
||||
)(test_case)
|
||||
|
||||
|
||||
def is_hopper():
|
||||
compute_capability = torch.cuda.get_device_capability()
|
||||
return compute_capability == (9, 0)
|
||||
|
||||
Reference in New Issue
Block a user