E2e device cuda (#575)
* use torch.cuda.current_device() instead of local_rank
* ignore NVML errors for gpu stats
* llama lora packing e2e tests
@@ -2,6 +2,7 @@
 import pynvml
 import torch
+from pynvml.nvml import NVMLError
 
 
 def gpu_memory_usage(device=0):
@@ -20,11 +21,13 @@ def gpu_memory_usage_smi(device=0):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
 
-    pynvml.nvmlInit()
-    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
-    info = pynvml.nvmlDeviceGetMemoryInfo(handle)
-    return info.used / 1024.0**3
+    try:
+        pynvml.nvmlInit()
+        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
+        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
+        return info.used / 1024.0**3
+    except NVMLError:
+        return 0.0
 
 
 def log_gpu_memory_usage(log, msg, device):
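The net effect of this hunk, as a minimal self-contained sketch (the helper name gpu_memory_used_gb is hypothetical, not the project's): NVML queries that fail, e.g. on CPU-only runners without an NVIDIA driver, now degrade to 0.0 instead of raising and aborting stats logging.

# Sketch of the guarded-NVML pattern this hunk introduces (hypothetical helper name).
import pynvml
from pynvml.nvml import NVMLError


def gpu_memory_used_gb(device: int = 0) -> float:
    try:
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(device)
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.used / 1024.0**3  # bytes -> GiB
    except NVMLError:
        # No usable NVML/driver stack: report zero instead of crashing.
        return 0.0


print(f"GPU 0 used: {gpu_memory_used_gb(0):.2f} GiB")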
@@ -29,7 +29,7 @@ def choose_device(cfg):
         cfg.device_map = "auto"
     else:
         if cfg.device.startswith("cuda"):
-            cfg.device_map = {"": cfg.local_rank}
+            cfg.device_map = {"": torch.cuda.current_device()}
         else:
             cfg.device_map = {"": cfg.device}
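A minimal sketch (hypothetical launcher setup, not taken from this repo) of why torch.cuda.current_device() is the safer source for the device_map entry: once each rank has pinned its GPU, current_device() reflects the device the process is actually bound to, which is what a mapping like {"": index} should point at (for example when passed to transformers' from_pretrained(device_map=...)).

# Sketch: derive the device_map from the GPU the process is bound to.
import os

import torch

local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)            # each rank binds to one GPU
    device_map = {"": torch.cuda.current_device()}  # follows the bound device
else:
    device_map = {"": "cpu"}

print(device_map)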