42 lines
1.1 KiB
Python
42 lines
1.1 KiB
Python
"""
|
|
utils to get GPU info for the current environment
|
|
"""
|
|
|
|
from importlib.metadata import version
|
|
|
|
from accelerate.utils.environment import (
|
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
|
)
|
|
from accelerate.utils.environment import (
|
|
get_gpu_info,
|
|
)
|
|
from packaging.version import Version, parse
|
|
|
|
|
|
def check_cuda_p2p_ib_support():
|
|
if not accelerate_check_cuda_p2p_ib_support():
|
|
return False
|
|
unsupported_devices = {"RTX 6000 Ada", "L40S"}
|
|
try:
|
|
device_names, device_count = get_gpu_info()
|
|
if 1 < device_count < 8:
|
|
if any(
|
|
unsupported_device in device_name
|
|
for device_name in device_names
|
|
for unsupported_device in unsupported_devices
|
|
):
|
|
return False
|
|
except Exception: # pylint: disable=broad-except # nosec
|
|
pass
|
|
return True
|
|
|
|
|
|
def get_package_version(package: str) -> Version:
|
|
version_str = version(package)
|
|
return parse(version_str)
|
|
|
|
|
|
def is_package_version_ge(package: str, version_: str) -> bool:
|
|
package_version = get_package_version(package)
|
|
return package_version >= parse(version_)
|