improve GPU logging to break out pytorch cache and system mem

Aman Karmani
2023-08-13 01:50:32 +00:00
committed by Aman Gupta Karmani
parent e029ab34ea
commit 7b55fe6419
6 changed files with 32 additions and 14 deletions

View File

@@ -18,7 +18,6 @@ from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 from axolotl.logging_config import configure_logging
-from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.config import normalize_config, validate_config
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
@@ -226,8 +225,6 @@ def train(
         LOG.info("Finished preparing dataset. Exiting...")
         return
-    log_gpu_memory_usage(LOG, "baseline", cfg.device)
     # Load the model and tokenizer
     LOG.info("loading model and (optionally) peft_config...")
     model, peft_config = load_model(cfg, tokenizer)

View File

@@ -4,13 +4,23 @@ import pynvml
 import torch
-def gpu_memory_usage(device):
+def gpu_memory_usage(device=0):
+    return torch.cuda.memory_allocated(device) / 1024.0**3
+def gpu_memory_usage_all(device=0):
+    usage = torch.cuda.memory_allocated(device) / 1024.0**3
+    reserved = torch.cuda.memory_reserved(device) / 1024.0**3
+    smi = gpu_memory_usage_smi(device)
+    return usage, reserved - usage, max(0, smi - reserved)
+def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):
         device = device.index
     if isinstance(device, str) and device.startswith("cuda:"):
         device = int(device[5:])
     # NB torch.cuda.memory_usage returns zero so we use lower level api
     pynvml.nvmlInit()
     handle = pynvml.nvmlDeviceGetHandleByIndex(device)
     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
@@ -18,6 +28,13 @@ def gpu_memory_usage(device):
 def log_gpu_memory_usage(log, msg, device):
+    usage, cache, misc = gpu_memory_usage_all(device)
+    extras = []
+    if cache > 0:
+        extras.append(f"+{cache:.03f}GB cache")
+    if misc > 0:
+        extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {gpu_memory_usage(device):.03f} GB", stacklevel=2
+        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
     )
+    return usage, cache, misc
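Taken together, these helpers split the number nvidia-smi reports into three buckets: memory PyTorch has actually allocated for live tensors, the extra memory its caching allocator has reserved from the driver (the "cache"), and whatever else the driver sees as used, such as the CUDA context or other processes ("misc"). A minimal standalone sketch of the same arithmetic, assuming a CUDA device at index 0 and the pynvml package; it mirrors the new bench helpers but is not the repo module verbatim:

import pynvml
import torch

GB = 1024.0**3

def memory_breakdown(device=0):
    # Live tensor allocations tracked by PyTorch's allocator
    allocated = torch.cuda.memory_allocated(device) / GB
    # Everything the caching allocator holds from the driver (allocated + reusable cache)
    reserved = torch.cuda.memory_reserved(device) / GB
    # Device-wide usage as nvidia-smi reports it (CUDA context, other processes, ...)
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    smi_used = pynvml.nvmlDeviceGetMemoryInfo(handle).used / GB
    return allocated, reserved - allocated, max(0.0, smi_used - reserved)

if __name__ == "__main__":
    used, cache, misc = memory_breakdown()
    print(f"GPU memory: {used:.03f}GB (+{cache:.03f}GB cache, +{misc:.03f}GB misc)")

The cache bucket is memory PyTorch can hand back to new tensors without another driver allocation, so a large cache is usually harmless; a large misc bucket points at memory the training process does not control.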

View File

@@ -74,10 +74,10 @@ class SaveBetterTransformerModelCallback(
         return control
-class PrintGPUStatsCallback(
+class GPUStatsCallback(
     TrainerCallback
 ): # pylint: disable=too-few-public-methods disable=unused-argument
-    """Callback to print GPU utilization"""
+    """Callback to track GPU utilization"""
     def __init__(self, cfg):
         self.cfg = cfg
@@ -90,7 +90,7 @@ class PrintGPUStatsCallback(
         control: TrainerControl,
         **kwargs,
     ):
-        if not self.logged:
+        if not self.logged and state.global_step > 1:
             log_gpu_memory_usage(LOG, "while training", self.cfg.device)
             self.logged = True
         return control
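The extra state.global_step > 1 condition delays the one-time measurement until at least one full optimizer step has finished, so gradients and optimizer state are already resident and the logged number reflects steady-state training memory rather than the first step's partially built-up footprint. A minimal sketch of the same log-once pattern on a bare transformers TrainerCallback; the class name and log_fn hook are illustrative, not part of the repo:

from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments

class LogMemoryOnceCallback(TrainerCallback):
    """Run a logging hook a single time, once training has passed step 1."""

    def __init__(self, log_fn):
        self.log_fn = log_fn  # e.g. a partial wrapping log_gpu_memory_usage
        self.logged = False

    def on_step_end(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        # Wait until after the first optimizer step so optimizer state is allocated
        if not self.logged and state.global_step > 1:
            self.log_fn()
            self.logged = True
        return control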

View File

@@ -5,6 +5,8 @@ import os
 import torch
+from axolotl.utils.bench import log_gpu_memory_usage
 LOG = logging.getLogger("axolotl")
@@ -54,6 +56,8 @@ def normalize_config(cfg):
     else:
         torch.backends.cuda.matmul.allow_tf32 = cfg.tf32 or False
+    log_gpu_memory_usage(LOG, "baseline", cfg.device)
 def validate_config(cfg):
     if cfg.max_packed_sequence_len and cfg.sample_packing:
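Because log_gpu_memory_usage now also returns its measurements, a caller such as this new baseline site can keep the numbers alongside the log line. A short usage sketch, assuming axolotl is importable and a CUDA device is visible; the bare device index 0 stands in for cfg.device:

import logging

from axolotl.utils.bench import log_gpu_memory_usage

LOG = logging.getLogger("axolotl")

# Emits the "GPU memory usage baseline: ..." line and hands back the three buckets in GB
usage, cache, misc = log_gpu_memory_usage(LOG, "baseline", 0)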

View File

@@ -381,9 +381,6 @@ def load_model(
             module.scales = module.scales.half()
             module.bias = module.bias.half()
-    if model.device.type == "cuda":
-        log_gpu_memory_usage(LOG, "after adapters", model.device)
     if (
         torch.cuda.device_count() > 1
         and int(os.getenv("WORLD_SIZE", "1")) > 1
@@ -406,6 +403,9 @@ def load_model(
     if cfg.flash_optimum:
         model = BetterTransformer.transform(model)
+    if cfg.adapter is not None:
+        log_gpu_memory_usage(LOG, "after adapters", model.device)
     # TODO resume_from_checkpoint handling
     return model, lora_config

View File

@@ -22,7 +22,7 @@ from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
 from axolotl.utils.callbacks import (
-    PrintGPUStatsCallback,
+    GPUStatsCallback,
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
@@ -555,7 +555,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
     trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
     callbacks = []
-    callbacks.append(PrintGPUStatsCallback(cfg))
+    callbacks.append(GPUStatsCallback(cfg))
     # TODO on_save callback to sync checkpoints to GCP/AWS in background
     if cfg.early_stopping_patience:
         early_stop_cb = EarlyStoppingCallback(