From 3a5c97e6e5899cbeb7ea284c658712217bf9721c Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 9 Oct 2025 14:17:31 -0400 Subject: [PATCH] use can_device_access_peer for P2P checks (#3209) [skip ci] * use can_device_access_peer for P2P checks * also log warn when automatically setting NCCL_P2P_DISABLE=1 --- src/axolotl/utils/environment.py | 60 +++++++++++--------------------- src/axolotl/utils/trainer.py | 1 + 2 files changed, 21 insertions(+), 40 deletions(-) diff --git a/src/axolotl/utils/environment.py b/src/axolotl/utils/environment.py index 7b2348413..d5f2d9f78 100644 --- a/src/axolotl/utils/environment.py +++ b/src/axolotl/utils/environment.py @@ -3,66 +3,46 @@ utils to get GPU info for the current environment """ import os -import subprocess # nosec B404 from importlib.metadata import version +import torch from accelerate.utils.environment import ( check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support, - get_gpu_info, ) from packaging.version import Version, parse +from axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + def check_cuda_p2p_ib_support(): if not accelerate_check_cuda_p2p_ib_support(): return False - if not check_runpod_p2p_support(): + if not check_cuda_p2p_support(): return False - unsupported_devices = {"RTX 6000 Ada", "L40S"} - try: - device_names, device_count = get_gpu_info() - if 1 < device_count < 8: - if any( - unsupported_device in device_name - for device_name in device_names - for unsupported_device in unsupported_devices - ): - return False - except Exception: # nosec B110 - pass return True -def check_runpod_p2p_support() -> bool: - if "RUNPOD_GPU_COUNT" not in os.environ: - return True +def check_cuda_p2p_support() -> bool: try: - gpu_count = int(os.environ.get("RUNPOD_GPU_COUNT", "1")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) except ValueError: return True - if gpu_count >= 2: - # run `nvidia-smi topo -p2p n` and inspect the GPU0 row + + if world_size > 1: + node_world_size = int(os.environ.get("NODE_WORLD_SIZE", "8")) + local_other_rank = (local_rank // node_world_size) * node_world_size + local_other_rank += 1 if (local_rank % node_world_size) == 0 else 0 try: - result = subprocess.run( # nosec B603 B607 - ["nvidia-smi", "topo", "-p2p", "n"], - check=True, - capture_output=True, - text=True, - timeout=5, - ) - except ( - subprocess.CalledProcessError, - FileNotFoundError, - subprocess.TimeoutExpired, - ): - return True # fail-open if detection fails - output_lines = result.stdout.strip().split("\n") - # filter rows that start with "GPU0" (avoid header row) - gpu0_rows = [line for line in output_lines if line.lstrip().startswith("GPU0")] - if not gpu0_rows: + can_p2p = torch.cuda.can_device_access_peer(local_rank, local_other_rank) + except AssertionError as exc: + # some sort of logic error in indexing processes, assume p2p is fine for now + LOG.warning(exc) return True - # consider P2P supported if any OK is present in the GPU0 row - return "OK" in gpu0_rows[-1] + return can_p2p + return True diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index c7fa0a647..f2f8279f3 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -641,6 +641,7 @@ def setup_parallelism_envs(cfg): def prepare_optim_env(cfg): if not check_cuda_p2p_ib_support(): if os.getenv("NCCL_P2P_DISABLE") is None: + LOG.warning("P2P support not detected, setting `NCCL_P2P_DISABLE=1`") os.environ["NCCL_P2P_DISABLE"] = "1" # TODO @SalmanMohammadi remove the cfg.fsdp check in 0.12 if cfg.fsdp or cfg.fsdp_config: