# Reconstructed from a whitespace-collapsed git patch:
#   commit d4cff1b7bbd43d546d95b31943cf2810e30efe8f
#   author: Wing Lian, Tue, 16 Sep 2025 14:52:45 -0400
#   subject: improve setting of NCCL_P2P_DISABLE on runpod (#3132) [skip ci]
#   target:  src/axolotl/utils/environment.py (37 insertions)
#
# Besides the function below, the patch makes check_cuda_p2p_ib_support()
# return False when check_runpod_p2p_support() returns False. That function
# and get_package_version() appear in the patch only as diff context
# (the former incompletely), so they are not reproduced here.

import os
import subprocess  # nosec B404


def _gpu0_row_reports_p2p_ok(topo_output: str) -> bool:
    """Return whether the GPU0 row of a P2P topology matrix contains ``OK``.

    Args:
        topo_output: Captured stdout of ``nvidia-smi topo -p2p n``.

    Returns:
        ``True`` when no row starting with ``GPU0`` exists (nothing to
        inspect, so fail open) or when the last such row has at least one
        cell equal to ``OK``; ``False`` otherwise.
    """
    gpu0_rows = [
        line
        for line in topo_output.strip().splitlines()
        if line.lstrip().startswith("GPU0")
    ]
    if not gpu0_rows:
        return True
    # The column-header line also starts with "GPU0" once leading whitespace
    # is stripped; it is printed first, so the last match is the data row.
    # Compare whole whitespace-separated cells (skipping the row label) so
    # only an exact "OK" status counts.
    status_cells = gpu0_rows[-1].split()[1:]
    return "OK" in status_cells


def check_runpod_p2p_support() -> bool:
    """Check whether GPU peer-to-peer looks usable on a multi-GPU RunPod host.

    The check only applies when the ``RUNPOD_GPU_COUNT`` environment variable
    is present (presumably set by the RunPod platform; not verifiable from
    this file) and parses to an integer of at least 2. In that case
    ``nvidia-smi topo -p2p n`` is run and the GPU0 row of its output is
    inspected for an ``OK`` entry.

    Every failure to detect is treated as "supported" (fail-open): a missing
    or non-integer variable, fewer than two GPUs, ``nvidia-smi`` being
    absent or not launchable, a non-zero exit status, a timeout after five
    seconds, or output without a GPU0 row.

    Returns:
        ``False`` only when the GPU0 row was found and contains no ``OK``
        entry; ``True`` in every other case.
    """
    raw_count = os.environ.get("RUNPOD_GPU_COUNT")
    if raw_count is None:
        return True
    try:
        gpu_count = int(raw_count)
    except ValueError:
        return True
    if gpu_count < 2:
        return True

    try:
        result = subprocess.run(  # nosec B603 B607
            ["nvidia-smi", "topo", "-p2p", "n"],
            check=True,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # OSError covers a missing binary as well as other launch failures
        # (e.g. PermissionError); SubprocessError covers CalledProcessError
        # and TimeoutExpired. All of them mean "could not detect".
        return True  # fail-open if detection fails

    return _gpu0_row_reports_p2p_ok(result.stdout)