add P2P env when multi-gpu but not the full node (#2041)
Co-authored-by: Wing Lian <wing@axolotl.ai>
This commit is contained in:
25
src/axolotl/utils/environment.py
Normal file
25
src/axolotl/utils/environment.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""
|
||||||
|
utils to get GPU info for the current environment
|
||||||
|
"""
|
||||||
|
from accelerate.utils.environment import (
|
||||||
|
check_cuda_p2p_ib_support as accelerate_check_cuda_p2p_ib_support,
|
||||||
|
)
|
||||||
|
from accelerate.utils.environment import get_gpu_info
|
||||||
|
|
||||||
|
|
||||||
|
def check_cuda_p2p_ib_support():
    """
    Decide whether CUDA P2P/IB communication should be treated as supported.

    Defers first to accelerate's own ``check_cuda_p2p_ib_support``; if that
    reports no support, this returns ``False`` immediately. Otherwise, on a
    multi-GPU machine that is not a full node (more than 1 and fewer than 8
    devices), it returns ``False`` when any reported device name contains one
    of the known-unsupported model markers.

    Returns:
        bool: ``False`` if P2P should be considered unsupported, else ``True``.
        If GPU info cannot be queried, the check is best-effort and falls back
        to ``True``.
    """
    if not accelerate_check_cuda_p2p_ib_support():
        return False

    # Short model markers; the names reported by the driver are typically
    # longer (e.g. "NVIDIA RTX 6000 Ada Generation"), so match by substring.
    unsupported_devices = {"RTX 6000 Ada"}
    try:
        device_names, device_count = get_gpu_info()
        # Only applies to multi-GPU setups that are not a full 8-GPU node.
        if 1 < device_count < 8 and any(
            # The marker must be contained in the reported device name,
            # not the other way around.
            unsupported_device in device_name
            for device_name in device_names
            for unsupported_device in unsupported_devices
        ):
            return False
    except Exception:  # pylint: disable=broad-except # nosec
        # Deliberate best-effort: a failure to query GPU info must not block
        # training setup, so fall through to the permissive default.
        pass
    return True
|
||||||
@@ -17,6 +17,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|||||||
|
|
||||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
|
||||||
LOG = get_logger("axolotl")
|
LOG = get_logger("axolotl")
|
||||||
@@ -461,6 +462,9 @@ def setup_fsdp_envs(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def prepare_optim_env(cfg):
|
def prepare_optim_env(cfg):
|
||||||
|
if not check_cuda_p2p_ib_support():
|
||||||
|
if os.getenv("NCCL_P2P_DISABLE") is None:
|
||||||
|
os.environ["NCCL_P2P_DISABLE"] = "1"
|
||||||
if cfg.fsdp:
|
if cfg.fsdp:
|
||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
|
|||||||
Reference in New Issue
Block a user