improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]

* improve setting of NCCL_P2P_DISABLE on runpod

* use recs from review
This commit is contained in:
Wing Lian
2025-09-16 14:52:45 -04:00
committed by GitHub
parent 1ef6c196f7
commit d4cff1b7bb

View File

@@ -2,6 +2,8 @@
utils to get GPU info for the current environment
"""
import os
import subprocess # nosec B404
from importlib.metadata import version
from accelerate.utils.environment import (
@@ -14,6 +16,8 @@ from packaging.version import Version, parse
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
if not check_runpod_p2p_support():
return False
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
@@ -29,6 +33,39 @@ def check_cuda_p2p_ib_support():
return True
def check_runpod_p2p_support() -> bool:
if "RUNPOD_GPU_COUNT" not in os.environ:
return True
try:
gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1"))
except ValueError:
return True
if gpu_count >= 2:
# run `nvidia-smi topo -p2p n` and inspect the GPU0 row
try:
result = subprocess.run( # nosec B603 B607
["nvidia-smi", "topo", "-p2p", "n"],
check=True,
capture_output=True,
text=True,
timeout=5,
)
except (
subprocess.CalledProcessError,
FileNotFoundError,
subprocess.TimeoutExpired,
):
return True # fail-open if detection fails
output_lines = result.stdout.strip().split("\n")
# filter rows that start with "GPU0" (avoid header row)
gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")]
if not gpu0_rows:
return True
# consider P2P supported if any OK is present in the GPU0 row
return "OK" in gpu0_rows[-1]
return True
def get_package_version(package: str) -> Version:
version_str = version(package)
return parse(version_str)